cppbasedann/annclasses.h
2022-01-18 13:03:53 +01:00

80 lines
No EOL
2.1 KiB
C++
Executable file

#ifndef MYCLASSES_H
#define MYCLASSES_H
#include <forward_list>
#include <list>
#include <vector>
/// Selects an activation function. Consumed by Tools::activation_function /
/// Tools::activation_function_derivative, by Neuron::activate (default LINEAR)
/// and by Network's vector constructor (defaults RELU / SIGMOID).
/// Kept as an unscoped enum so the enumerators are usable unqualified; the
/// values below are exactly the implicit 0..4 numbering.
enum Activ
{
    RELU    = 0,
    TANH    = 1,
    SIGMOID = 2,
    LINEAR  = 3,
    SOFTMAX = 4
};
/// One neuron of the network: a bias, one weight per neuron of the previous
/// layer, and per-neuron scalar state exposed through get_/set_ accessors.
class Neuron
{
public:
    /// @param prev_layer_size number of weights this neuron holds
    ///        (one per neuron of the previous layer).
    Neuron(int prev_layer_size);

    void set_bias(float value);
    float get_bias();

    /// Access the n-th weight. NOTE(review): `weights` is a forward_list, so
    /// indexed access presumably walks the list (linear in n) — confirm in the
    /// definition; index validity is not checked by anything visible here.
    void set_nth_weight(int n, float value);
    float get_nth_weight(int n);

    float get_weighted_sum();

    void set_activated_output(float value);
    float get_activated_output();

    /// NOTE(review): `derror` presumably holds this neuron's error term used by
    /// Network::backward — confirm against the definitions.
    void set_derror(float value);
    float get_derror();

    /// Activates the neuron using `activ_function` (default LINEAR).
    /// NOTE(review): presumably reads the previous layer starting at
    /// `prev_layer_it`; the iterator is taken by non-const reference, so the
    /// callee may advance it — confirm against the definition.
    void activate(std::forward_list<Neuron>::iterator &prev_layer_it, Activ activ_function=LINEAR);

private:
    std::forward_list<float> weights; // one entry per previous-layer neuron
    // In-class initializers guarantee a defined value even if the constructor
    // (defined elsewhere) does not assign these; any assignment it makes still
    // takes precedence.
    float bias{0.0f};
    float weighted_sum{0.0f};
    float activated_output{0.0f};
    float derror{0.0f};
};
/// A network stored as a list of layers, each layer a forward_list of Neuron.
class Network
{
public:
    /// Builds a network from a layer count and a per-layer neuron count
    /// (presumably every layer gets `n_neurons` neurons — confirm).
    /// This constructor takes no activation arguments; the in-class defaults
    /// of `h_activ`/`o_activ` below apply unless its definition sets them.
    Network(int n_layers, int n_neurons);

    /// Builds a network whose layer sizes are given by `n_neurons`.
    /// @param h_activ activation presumably used by hidden layers (default RELU)
    /// @param o_activ activation presumably used by the output layer (default SIGMOID)
    Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);

    /// Runs the network on `input` and returns a single float.
    /// NOTE(review): meaning of `as_raw` is not visible here — confirm in the definition.
    float predict(const std::vector<float> &input, bool as_raw=true);

    void print();

    // TODO: forward / set_errors / backward are public only temporarily; the
    // author marked them "to be deleted" from the public section, with the
    // intent of moving them to the private section.
    bool forward(const std::vector<float> &input, const std::vector<float> &target);
    bool set_errors(const std::vector<float> &target);
    bool backward(float learning_rate);

private:
    std::list<std::forward_list<Neuron>> layers;
    // Defaults mirror the vector constructor's default arguments so the
    // (n_layers, n_neurons) constructor never leaves them indeterminate.
    Activ h_activ{RELU};
    Activ o_activ{SIGMOID};
};
/// Stateless numeric helpers (the class has no data members): random
/// numbers, plus activation functions and their derivatives.
class Tools
{
public:
    /// NOTE(review): presumably seeds the random source used by get_random — confirm.
    static void activate_randomness();

    /// Returns a random float bounded by `mini` and `maxi`
    /// (confirm inclusivity of the bounds in the definition).
    static float get_random(float mini, float maxi);

    /// Applies the activation selected by `activ` to `value`.
    static float activation_function(Activ activ, float value);

    /// Derivative of the activation selected by `activ` at `value`.
    /// NOTE(review): confirm whether `value` is the pre- or post-activation quantity.
    static float activation_function_derivative(Activ activ, float value);

private:
    // Static so the static dispatchers above can call them directly: as
    // non-static members they are unreachable from a static member function
    // without a Tools object, and Tools holds no state that would need one.
    static float relu(float value);
    static float sigmoid(float value);
    static float relu_derivative(float value);
    static float sigmoid_derivative(float value);
    static float tanh_derivative(float value);
};
#endif