No Description
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

annclasses.h 2.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #ifndef ANNCLASSES_H
  2. #define ANNCLASSES_H
  3. #include <forward_list>
  4. #include <list>
  5. #include <vector>
  /// Identifiers for the supported activation functions.
  /// Kept as an unscoped enum: the enumerators are referenced unqualified
  /// (e.g. as default arguments) elsewhere in this header.
  enum Activ
  {
      RELU = 0,
      TANH = 1,
      SIGMOID = 2,
      LINEAR = 3,
      SOFTMAX = 4
  };
  10. class Neuron
  11. {
  12. public:
  13. Neuron(int prev_layer_size); //prev_layer_size = number of weights
  14. void set_bias(float value);
  15. float get_bias();
  16. void set_nth_weight(int n, float value);
  17. float get_nth_weight(int n);
  18. float get_weighted_sum();
  19. void set_activated_output(float value);
  20. float get_activated_output();
  21. void set_derror(float value);
  22. float get_derror();
  23. void activate(std::forward_list<Neuron>::iterator &prev_layer_it, Activ activ_function=LINEAR);
  24. private:
  25. std::forward_list<float> weights;
  26. float bias;
  27. float weighted_sum;
  28. float activated_output;
  29. float derror;
  30. };
  31. class Network
  32. {
  33. public:
  34. Network(int n_layers, int n_neurons);
  35. Network(const std::vector<int> &n_neurons, Activ h_activ=RELU, Activ o_activ=SIGMOID);
  36. int get_neurons_number();
  37. bool train(const std::vector<std::vector<float>> &inputs, const std::vector<std::vector<float>> &targets, float learning_rate=0.001, int n_episodes=30, int batch_size=32);
  38. std::vector<float> predict(const std::vector<std::vector<float>> &inputs, bool as_raw=true);
  39. void print();
  40. //to be deleted
  41. //bool forward(const std::vector<float> &input, const std::vector<float> &target);
  42. //bool set_errors(const std::vector<float> &target);
  43. //bool backward(float learning_rate);
  44. private:
  45. std::list<std::forward_list<Neuron>> layers;
  46. int neurons_number;
  47. Activ h_activ;
  48. Activ o_activ;
  49. bool forward(const std::vector<float> &input, const std::vector<float> &target);
  50. bool set_errors(const std::vector<float> &target);
  51. bool backward(float learning_rate);
  52. };
  53. class Tools
  54. {
  55. public:
  56. static void activate_randomness();
  57. static float get_random(float mini, float maxi);
  58. //Activation functions and their derivatives
  59. static float activation_function(Activ activ, float value);
  60. static float activation_function_derivative(Activ activ, float value);
  61. //float activation_function(Activ activ, float value);
  62. //float activation_function_derivative(Activ activ, float value);
  63. private:
  64. float relu(float value);
  65. float sigmoid(float value);
  66. float relu_derivative(float value);
  67. float sigmoid_derivative(float value);
  68. float tanh_derivative(float value);
  69. };
  70. #endif