Updating train function in Network class

This commit is contained in:
chabisik 2022-02-07 02:40:26 +01:00
parent 8fd19c2d8f
commit d499f8a29a

View file

@ -154,43 +154,82 @@ bool Network::train(const vector<vector<float>> &inputs, const vector<vector<flo
{ {
if(inputs.size() == targets.size()) if(inputs.size() == targets.size())
{ {
vector<vector<float>> all_activated_outputs(get_neurons_number()); //vector<vector<float>> all_activated_outputs(get_neurons_number());
vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size()); //vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size());
bool is_constructed = false;
for(int episode=1 ; episode<=n_episodes ; episode++) for(int episode=1 ; episode<=n_episodes ; episode++)
{ {
for(int index(0) ; index<inputs.size() ; index++)//batch_size not yet used for(int batch_index(0) ; batch_index<inputs.size() ; batch_index+=batch_size)
{ {
forward(inputs.at(index), targets.at(index)); vector<vector<float>> all_activated_outputs(get_neurons_number());
set_errors(targets.at(index)); vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size());
int layer_counter = 0;
int neurons_counter1 = 0; int neurons_counter1 = 0;
int neurons_counter2 = 0; int neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator current_layer(layers.begin()) ; current_layer!=layers.end() ; ++current_layer) for(int index(batch_index) ; index<inputs.size() && index<batch_index+batch_size ; index++)//batch_size not yet used
{
forward(inputs.at(index), targets.at(index));
set_errors(targets.at(index));
int layer_counter = 0;
//int neurons_counter1 = 0;
//int neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator cur_layer(layers.begin()) ; cur_layer!=layers.end() ; ++cur_layer)
{
layer_counter++;
if(layer_counter==1)
{
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
}
}else
{
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
all_derrors.at(neurons_counter2).push_back( cur_neuron->get_derror() );
neurons_counter2++;
}
}
}
}
int layer_counter = 0;
neurons_counter1 = 0;
neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator cur_layer(layers.begin()) ; cur_layer!=layers.end() ; ++cur_layer)
{ {
layer_counter++; layer_counter++;
if(layer_counter==1) if(layer_counter==1)
{ {
for(forward_list<Neuron>::iterator current_neuron(current_layer->begin()) ; current_neuron!=current_layer->end() ; ++current_neuron) for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{ {
all_activated_outputs.at(neurons_counter1).push_back( current_neuron->get_activated_output() ); cur_neuron->set_activated_output( accumulate(all_activated_outputs.at(neurons_counter1).begin(),
all_activated_outputs.at(neurons_counter1).end(),0)/all_activated_outputs.at(neurons_counter1).size() );
//all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++; neurons_counter1++;
} }
}else }else
{ {
for(forward_list<Neuron>::iterator current_neuron(current_layer->begin()) ; current_neuron!=current_layer->end() ; ++current_neuron) for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{ {
all_activated_outputs.at(neurons_counter1).push_back( current_neuron->get_activated_output() ); cur_neuron->set_activated_output( accumulate(all_activated_outputs.at(neurons_counter1).begin(),
all_activated_outputs.at(neurons_counter1).end(),0)/all_activated_outputs.at(neurons_counter1).size() );
//all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++; neurons_counter1++;
all_derrors.at(neurons_counter2).push_back( current_neuron->get_derror() ); cur_neuron->set_derror( accumulate(all_derrors.at(neurons_counter2).begin(),
all_derrors.at(neurons_counter2).end(),0)/all_derrors.at(neurons_counter2).size() );
//all_derrors.at(neurons_counter2).push_back( cur_neuron->get_derror() );
neurons_counter2++; neurons_counter2++;
} }
} }
} }
backward(learning_rate);
} }
backward(learning_rate); //backward(learning_rate);
} }
}else }else
{ {