Changing forward, set_errors and backward function to private in Network class

This commit is contained in:
chabisik 2022-01-18 13:27:55 +01:00
parent 7af273847e
commit 756056f55c
3 changed files with 22 additions and 14 deletions

View file

@ -11,8 +11,8 @@ Neuron::Neuron(int prev_layer_size)
{
for(int i(1) ; i<=prev_layer_size ; i++)
{
//weights.push_front(Tools::get_random(0.0, 1.0));
weights.push_front(1.0);
weights.push_front(Tools::get_random(0.0, 1.0));
//weights.push_front(1.0);
}
bias = 0.1;
weighted_sum = 0.0;
@ -140,6 +140,17 @@ Network::Network(const std::vector<int> &n_neurons, Activ h_activ, Activ o_activ
o_activ = o_activ;
}
bool Network::train(const std::vector<float> &input, const std::vector<float> &target, float learning_rate, int n_episodes)
{
for(int episode=1;episode<=n_episodes;episode++)
{
forward(input, target);
set_errors(target);
backward(learning_rate);
}
return true;
}
bool Network::forward(const std::vector<float> &input, const std::vector<float> &target)
{
int layer_counter = 0;
@ -174,7 +185,7 @@ bool Network::forward(const std::vector<float> &input, const std::vector<float>
}
}
}
set_errors(target);
//set_errors(target);
return true;
}

View file

@ -40,22 +40,23 @@ public:
Network(int n_layers, int n_neurons);
Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);
bool train(const std::vector<float> &input, const std::vector<float> &target, float learning_rate, int n_episodes);
float predict(const std::vector<float> &input, bool as_raw=true);
void print();
//to be deleted
bool forward(const std::vector<float> &input, const std::vector<float> &target);
bool set_errors(const std::vector<float> &target);
bool backward(float learning_rate);
//bool forward(const std::vector<float> &input, const std::vector<float> &target);
//bool set_errors(const std::vector<float> &target);
//bool backward(float learning_rate);
private:
std::list<std::forward_list<Neuron>> layers;
Activ h_activ;
Activ o_activ;
//bool forward(const std::vector<float> &input, const std::vector<float> &target);
//bool set_errors(const std::vector<float> &target);
//bool backward(float learning_rate);
bool forward(const std::vector<float> &input, const std::vector<float> &target);
bool set_errors(const std::vector<float> &target);
bool backward(float learning_rate);
};

View file

@ -14,11 +14,7 @@ int main(int argc, char *argv[])
Network network(15, 3);
network.print();
cout << endl << endl;
for(int episode=1;episode<=100000;episode++)
{
network.forward({1.0,1.0,1.0}, {1.0,2.0,3.0});
network.backward(0.001);
}
network.train({1.0,1.0,1.0}, {1.0,2.0,3.0}, 0.001, 100000);
//network.print();
cout << endl << endl;
network.print();