Changing forward, set_errors and backward function to private in Network class
This commit is contained in:
parent
7af273847e
commit
756056f55c
3 changed files with 22 additions and 14 deletions
|
@ -11,8 +11,8 @@ Neuron::Neuron(int prev_layer_size)
|
||||||
{
|
{
|
||||||
for(int i(1) ; i<=prev_layer_size ; i++)
|
for(int i(1) ; i<=prev_layer_size ; i++)
|
||||||
{
|
{
|
||||||
//weights.push_front(Tools::get_random(0.0, 1.0));
|
weights.push_front(Tools::get_random(0.0, 1.0));
|
||||||
weights.push_front(1.0);
|
//weights.push_front(1.0);
|
||||||
}
|
}
|
||||||
bias = 0.1;
|
bias = 0.1;
|
||||||
weighted_sum = 0.0;
|
weighted_sum = 0.0;
|
||||||
|
@ -140,6 +140,17 @@ Network::Network(const std::vector<int> &n_neurons, Activ h_activ, Activ o_activ
|
||||||
o_activ = o_activ;
|
o_activ = o_activ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Network::train(const std::vector<float> &input, const std::vector<float> &target, float learning_rate, int n_episodes)
|
||||||
|
{
|
||||||
|
for(int episode=1;episode<=n_episodes;episode++)
|
||||||
|
{
|
||||||
|
forward(input, target);
|
||||||
|
set_errors(target);
|
||||||
|
backward(learning_rate);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool Network::forward(const std::vector<float> &input, const std::vector<float> &target)
|
bool Network::forward(const std::vector<float> &input, const std::vector<float> &target)
|
||||||
{
|
{
|
||||||
int layer_counter = 0;
|
int layer_counter = 0;
|
||||||
|
@ -174,7 +185,7 @@ bool Network::forward(const std::vector<float> &input, const std::vector<float>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
set_errors(target);
|
//set_errors(target);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
13
annclasses.h
13
annclasses.h
|
@ -40,22 +40,23 @@ public:
|
||||||
Network(int n_layers, int n_neurons);
|
Network(int n_layers, int n_neurons);
|
||||||
Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);
|
Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);
|
||||||
|
|
||||||
|
bool train(const std::vector<float> &input, const std::vector<float> &target, float learning_rate, int n_episodes);
|
||||||
|
|
||||||
float predict(const std::vector<float> &input, bool as_raw=true);
|
float predict(const std::vector<float> &input, bool as_raw=true);
|
||||||
void print();
|
void print();
|
||||||
|
|
||||||
//to be deleted
|
//to be deleted
|
||||||
bool forward(const std::vector<float> &input, const std::vector<float> &target);
|
//bool forward(const std::vector<float> &input, const std::vector<float> &target);
|
||||||
bool set_errors(const std::vector<float> &target);
|
//bool set_errors(const std::vector<float> &target);
|
||||||
bool backward(float learning_rate);
|
//bool backward(float learning_rate);
|
||||||
private:
|
private:
|
||||||
std::list<std::forward_list<Neuron>> layers;
|
std::list<std::forward_list<Neuron>> layers;
|
||||||
Activ h_activ;
|
Activ h_activ;
|
||||||
Activ o_activ;
|
Activ o_activ;
|
||||||
|
|
||||||
//bool forward(const std::vector<float> &input, const std::vector<float> &target);
|
bool forward(const std::vector<float> &input, const std::vector<float> &target);
|
||||||
//bool set_errors(const std::vector<float> &target);
|
bool set_errors(const std::vector<float> &target);
|
||||||
//bool backward(float learning_rate);
|
bool backward(float learning_rate);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
6
main.cpp
6
main.cpp
|
@ -14,11 +14,7 @@ int main(int argc, char *argv[])
|
||||||
Network network(15, 3);
|
Network network(15, 3);
|
||||||
network.print();
|
network.print();
|
||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
for(int episode=1;episode<=100000;episode++)
|
network.train({1.0,1.0,1.0}, {1.0,2.0,3.0}, 0.001, 100000);
|
||||||
{
|
|
||||||
network.forward({1.0,1.0,1.0}, {1.0,2.0,3.0});
|
|
||||||
network.backward(0.001);
|
|
||||||
}
|
|
||||||
//network.print();
|
//network.print();
|
||||||
cout << endl << endl;
|
cout << endl << endl;
|
||||||
network.print();
|
network.print();
|
||||||
|
|
Loading…
Reference in a new issue