cppbasedann/annclasses.h
2022-02-07 00:57:27 +01:00

84 lines
No EOL
2.4 KiB
C++
Executable file

#ifndef ANNCLASSES_H
#define ANNCLASSES_H
#include <forward_list>
#include <list>
#include <vector>
/// Identifiers for the activation functions understood by
/// Tools::activation_function / Tools::activation_function_derivative.
/// Kept as an unscoped enum because the enumerators are used unqualified
/// (e.g. as default arguments such as `Activ h_activ=RELU`).
enum Activ
{
    RELU = 0,
    TANH = 1,
    SIGMOID = 2,
    LINEAR = 3,
    SOFTMAX = 4
};
/// A single neuron: one weight per neuron of the previous layer, a bias, and
/// the stored weighted sum, activated output and error value.
class Neuron
{
public:
    /// @param prev_layer_size number of weights, i.e. the size of the previous layer.
    Neuron(int prev_layer_size); //prev_layer_size = number of weights

    void set_bias(float value);
    float get_bias();

    /// Write / read the n-th weight (NOTE(review): presumably 0-based — confirm in the .cpp).
    void set_nth_weight(int n, float value);
    float get_nth_weight(int n);

    float get_weighted_sum();

    void set_activated_output(float value);
    float get_activated_output();

    /// Stores / reads the neuron's error value
    /// (NOTE(review): presumably the backprop error derivative — confirm).
    void set_derror(float value);
    float get_derror();

    /// NOTE(review): presumably computes weighted_sum and activated_output from
    /// the previous layer's neurons — confirm in the .cpp.
    /// @param prev_layer_it  iterator into the previous layer; taken by non-const
    ///                       reference, so the callee may advance it (confirm).
    /// @param activ_function activation applied to the weighted sum.
    void activate(std::forward_list<Neuron>::iterator &prev_layer_it, Activ activ_function=LINEAR);

private:
    std::forward_list<float> weights;
    // Default member initializers: guarantee defined values even if a
    // constructor path never assigns them (reading an uninitialized float is UB).
    // A constructor that does assign them still overrides these.
    float bias{0.0f};
    float weighted_sum{0.0f};
    float activated_output{0.0f};
    float derror{0.0f};
};
/// A neural network stored as a sequence of layers; each layer is a
/// singly linked list of Neuron objects (NOTE(review): presumably fully
/// connected, since each Neuron holds one weight per previous-layer neuron).
class Network
{
public:
    /// @param n_layers  number of layers (NOTE(review): confirm whether input/output count).
    /// @param n_neurons number of neurons per layer.
    Network(int n_layers, int n_neurons);

    /// @param n_neurons one entry per layer giving that layer's neuron count
    ///                  (NOTE(review): presumably index 0 is the input layer — confirm).
    /// @param h_activ   activation stored for hidden layers (inferred from name).
    /// @param o_activ   activation stored for the output layer (inferred from name).
    Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);

    /// NOTE(review): presumably returns the neurons_number member — confirm.
    int get_neurons_number();

    /// Trains the network on `inputs` against `targets`.
    /// @return success flag (failure conditions are defined in the .cpp).
    bool train(const std::vector<std::vector<float>> &inputs, const std::vector<std::vector<float>> &targets, float learning_rate=0.001, int n_episodes=30, int batch_size=32);

    /// Runs the network on `inputs`.
    /// @param as_raw NOTE(review): presumably selects raw outputs vs. post-processed ones — confirm.
    std::vector<float> predict(const std::vector<std::vector<float>> &inputs, bool as_raw=true);

    /// Prints the network (format defined in the .cpp).
    void print();

private:
    std::list<std::forward_list<Neuron>> layers;
    // Default member initializers give these defined values even when a
    // constructor (e.g. Network(int, int), which receives no activations)
    // does not assign them; constructors that do assign them override these.
    int neurons_number{0};
    Activ h_activ{RELU};    // same default as the vector-based constructor
    Activ o_activ{SIGMOID}; // same default as the vector-based constructor

    bool forward(const std::vector<float> &input, const std::vector<float> &target);
    bool set_errors(const std::vector<float> &target);
    bool backward(float learning_rate);
};
/// Stateless helper collection (no data members): random numbers and
/// activation functions with their derivatives.
class Tools
{
public:
    /// NOTE(review): presumably seeds the generator used by get_random — confirm.
    static void activate_randomness();

    /// @return a random value between `mini` and `maxi`
    ///         (bound inclusivity is defined in the .cpp).
    static float get_random(float mini, float maxi);

    // Activation functions and their derivatives, selected by `activ`.
    static float activation_function(Activ activ, float value);
    static float activation_function_derivative(Activ activ, float value);

private:
    // These helpers are static so the static dispatchers above can call them:
    // a static member function has no `this` through which to call a
    // non-static member. An existing out-of-class definition such as
    // `float Tools::relu(float value) { ... }` still matches, because the
    // `static` keyword is not repeated outside the class definition.
    static float relu(float value);
    static float sigmoid(float value);
    static float relu_derivative(float value);
    static float sigmoid_derivative(float value);
    static float tanh_derivative(float value);
};
#endif