Update Network::train to iterate in mini-batches: accumulate each neuron's activated outputs and derrors over a batch, average them, then run a single backward pass per batch

This commit is contained in:
chabisik 2022-02-07 02:40:26 +01:00
parent 8fd19c2d8f
commit d499f8a29a

View file

@ -154,43 +154,82 @@ bool Network::train(const vector<vector<float>> &inputs, const vector<vector<flo
{
if(inputs.size() == targets.size())
{
vector<vector<float>> all_activated_outputs(get_neurons_number());
vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size());
bool is_constructed = false;
//vector<vector<float>> all_activated_outputs(get_neurons_number());
//vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size());
for(int episode=1 ; episode<=n_episodes ; episode++)
{
for(int index(0) ; index<inputs.size() ; index++)//batch_size not yet used
for(int batch_index(0) ; batch_index<inputs.size() ; batch_index+=batch_size)
{
forward(inputs.at(index), targets.at(index));
set_errors(targets.at(index));
int layer_counter = 0;
vector<vector<float>> all_activated_outputs(get_neurons_number());
vector<vector<float>> all_derrors(get_neurons_number()-inputs.at(0).size());
int neurons_counter1 = 0;
int neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator current_layer(layers.begin()) ; current_layer!=layers.end() ; ++current_layer)
for(int index(batch_index) ; index<inputs.size() && index<batch_index+batch_size ; index++)//batch_size not yet used
{
forward(inputs.at(index), targets.at(index));
set_errors(targets.at(index));
int layer_counter = 0;
//int neurons_counter1 = 0;
//int neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator cur_layer(layers.begin()) ; cur_layer!=layers.end() ; ++cur_layer)
{
layer_counter++;
if(layer_counter==1)
{
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
}
}else
{
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
all_derrors.at(neurons_counter2).push_back( cur_neuron->get_derror() );
neurons_counter2++;
}
}
}
}
int layer_counter = 0;
neurons_counter1 = 0;
neurons_counter2 = 0;
for(list<forward_list<Neuron>>::iterator cur_layer(layers.begin()) ; cur_layer!=layers.end() ; ++cur_layer)
{
layer_counter++;
if(layer_counter==1)
{
for(forward_list<Neuron>::iterator current_neuron(current_layer->begin()) ; current_neuron!=current_layer->end() ; ++current_neuron)
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( current_neuron->get_activated_output() );
cur_neuron->set_activated_output( accumulate(all_activated_outputs.at(neurons_counter1).begin(),
all_activated_outputs.at(neurons_counter1).end(),0)/all_activated_outputs.at(neurons_counter1).size() );
//all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
}
}else
{
for(forward_list<Neuron>::iterator current_neuron(current_layer->begin()) ; current_neuron!=current_layer->end() ; ++current_neuron)
for(forward_list<Neuron>::iterator cur_neuron(cur_layer->begin()) ; cur_neuron!=cur_layer->end() ; ++cur_neuron)
{
all_activated_outputs.at(neurons_counter1).push_back( current_neuron->get_activated_output() );
cur_neuron->set_activated_output( accumulate(all_activated_outputs.at(neurons_counter1).begin(),
all_activated_outputs.at(neurons_counter1).end(),0)/all_activated_outputs.at(neurons_counter1).size() );
//all_activated_outputs.at(neurons_counter1).push_back( cur_neuron->get_activated_output() );
neurons_counter1++;
all_derrors.at(neurons_counter2).push_back( current_neuron->get_derror() );
cur_neuron->set_derror( accumulate(all_derrors.at(neurons_counter2).begin(),
all_derrors.at(neurons_counter2).end(),0)/all_derrors.at(neurons_counter2).size() );
//all_derrors.at(neurons_counter2).push_back( cur_neuron->get_derror() );
neurons_counter2++;
}
}
}
backward(learning_rate);
}
backward(learning_rate);
//backward(learning_rate);
}
}else
{