import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Hashtable;
import java.util.Locale;
import java.util.Random;

// TDDC17 - Lab 5
// Part 2: Q-Learning Controller
// Last modification: 2020-10-06

/* GUI commands:
 * M : disable the GUI (much faster)
 * E : toggle exploration
 * V : speed up learning
 * P : turn the controller off
 * O : turn the controller on
 * W, A, S, D : manual commands
 */

/* TODO:
 * - Define state and reward functions (StateAndReward.java) suitable for your problem --> done.
 * - Define actions --> done.
 * - Implement the missing parts of Q-learning --> done.
 * - Tune the state and reward functions, and the parameters below, if the result is not satisfactory. */

public class QLearningController extends Controller {

    // Change this value for part 3 (state and reward for hovering)
    boolean HOVER_MODE = true;
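    /* HOVER_MODE selects which task definitions from StateAndReward.java are
     * used in tick(): false uses getStateAngle/getRewardAngle (angle
     * stabilization), true uses getStateHover/getRewardHover (hovering). */
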
    /* These are the agent's senses (inputs) */
    DoubleFeature x;  /* Positions */
    DoubleFeature y;
    DoubleFeature vx; /* Velocities */
    DoubleFeature vy;
    DoubleFeature angle; /* Angle */

    /* These are the agent's actuators (outputs) */
    RocketEngine leftEngine;
    RocketEngine middleEngine;
    RocketEngine rightEngine;

    final static int NUM_ACTIONS = 7; /* performAction() must be changed if this is modified */
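    /* Action table implemented by performAction():
     * 0: no engine, 1: left, 2: right, 3: middle, 4: left + middle,
     * 5: right + middle, 6: all three, 7: left + right. Action 7 is
     * unreachable while NUM_ACTIONS == 7, since selectAction() only returns
     * actions in 0..NUM_ACTIONS-1. */
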
    /* Keep track of the previous state and action */
    String previous_state = null;
    double previous_vx = 0;
    double previous_vy = 0;
    double previous_angle = 0;
    int previous_action = 0;

    /* The tables used by Q-learning */
    Hashtable<String, Double> Qtable = new Hashtable<String, Double>();   /* Contains the Q-values - the state-action utilities */
    Hashtable<String, Integer> Ntable = new Hashtable<String, Integer>(); /* Keeps track of how many times each state-action combination has been used */
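    /* Both tables are keyed by the concatenation of a state string and an
     * action index, e.g. previous_state + previous_action (see tick()); the
     * exact state format depends on your StateAndReward implementation. */
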
    /* PARAMETERS OF THE LEARNING ALGORITHM - THESE MAY BE TUNED, BUT THE DEFAULT VALUES OFTEN WORK REASONABLY WELL */
    static final double GAMMA_DISCOUNT_FACTOR = 0.95; /* Must be < 1; small values make the agent very greedy */
    static final double LEARNING_RATE_CONSTANT = 10;  /* See alpha(); lower values are good for quick results in large and deterministic state spaces */
    double explore_chance = 0.5; /* The exploration chance during the exploration phase */
    final static int REPEAT_ACTION_MAX = 30; /* Repeat the selected action at most this many times while trying to reach a new state; without a max it could loop forever if the action cannot lead to a new state */

    /* Some internal counters */
    int iteration = 0;      /* Keeps track of how many iterations the agent has run */
    int action_counter = 0; /* Keeps track of how many times we have repeated the current action */
    int print_counter = 0;  /* Makes printouts less spammy */

    /* These are just internal helper variables; you can ignore them */
    boolean paused = false;
    boolean explore = true; /* Will always do exploration by default */

    DecimalFormat df = (DecimalFormat) NumberFormat.getNumberInstance(Locale.US);
    public SpringObject object;
    ComposedSpringObject cso;
    long lastPressedExplore = 0;

    public void init() {
        cso = (ComposedSpringObject) object;
        x = (DoubleFeature) cso.getObjectById("x");
        y = (DoubleFeature) cso.getObjectById("y");
        vx = (DoubleFeature) cso.getObjectById("vx");
        vy = (DoubleFeature) cso.getObjectById("vy");
        angle = (DoubleFeature) cso.getObjectById("angle");

        previous_vy = vy.getValue();
        previous_vx = vx.getValue();
        previous_angle = angle.getValue();

        leftEngine = (RocketEngine) cso.getObjectById("rocket_engine_left");
        rightEngine = (RocketEngine) cso.getObjectById("rocket_engine_right");
        middleEngine = (RocketEngine) cso.getObjectById("rocket_engine_middle");
    }

    /* Turn off all rockets */
    void resetRockets() {
        leftEngine.setBursting(false);
        rightEngine.setBursting(false);
        middleEngine.setBursting(false);
    }

    /* Performs the chosen action */
    void performAction(int action) {
        /* Fire zeh rockets! */
        resetRockets();
        switch (action) {
            case 0: /* Do nothing */
                break;

            case 1: leftEngine.setBursting(true);
                break;

            case 2: rightEngine.setBursting(true);
                break;

            case 3: middleEngine.setBursting(true);
                break;

            case 4: leftEngine.setBursting(true);
                middleEngine.setBursting(true);
                break;

            case 5: rightEngine.setBursting(true);
                middleEngine.setBursting(true);
                break;

            case 6: middleEngine.setBursting(true);
                leftEngine.setBursting(true);
                rightEngine.setBursting(true);
                break;

            case 7: rightEngine.setBursting(true);
                leftEngine.setBursting(true);
                break;
        }
    }

    /* Main decision loop. Called every iteration by the simulator */
    public void tick(int currentTime) {
        iteration++;

        if (!paused) {
            String new_state;
            double previous_reward;

            // You can choose the value of HOVER_MODE where it is declared at the top of the class (for part 3)
            if (!HOVER_MODE) {
                new_state = StateAndReward.getStateAngle(angle.getValue(), vx.getValue(), vy.getValue());
                previous_reward = StateAndReward.getRewardAngle(previous_angle, previous_vx, previous_vy);
            } else {
                new_state = StateAndReward.getStateHover(angle.getValue(), vx.getValue(), vy.getValue());
                previous_reward = StateAndReward.getRewardHover(previous_angle, previous_vx, previous_vy);
            }

            /* Repeat the chosen action for a while, hoping to reach a new state. This is a trick to speed up learning on this problem. */
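            /* (A single simulator tick rarely moves the agent out of its current
             * discretized state, so without repetition most updates would compare
             * a state with itself.) */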
            action_counter++;
            if (new_state.equals(previous_state) && action_counter < REPEAT_ACTION_MAX) {
                return;
            }
            action_counter = 0;

            /* The agent is in a new state; do learning and action selection */
            if (previous_state != null) {

                /* Create the state-action key */
                String prev_stateaction = previous_state + previous_action;

                /* Increment the state-action counter */
                if (Ntable.get(prev_stateaction) == null) {
                    Ntable.put(prev_stateaction, 0);
                }
                Ntable.put(prev_stateaction, Ntable.get(prev_stateaction) + 1);

                /* Update the Q-value */
                if (Qtable.get(prev_stateaction) == null) {
                    Qtable.put(prev_stateaction, 0.0);
                } else {
                    // Here we use the temporal-difference Q-learning update, as defined in
                    // AI: A Modern Approach, 4th ed. (Russell & Norvig), p. 802, Eq. (22.7):
                    //   Q(s,a) <- Q(s,a) + alpha*[R(s,a,s') + gamma*max_a' Q(s',a') - Q(s,a)]
                    // where: alpha --> learning rate
                    //        gamma --> discount factor
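                    // Worked example: with alpha = 0.5 and gamma = 0.95, if
                    // Q(s,a) = 0, R = 1 and max_a' Q(s',a') = 2, the update gives
                    // Q(s,a) <- 0 + 0.5*(1 + 0.95*2 - 0) = 1.45.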

                    double gamma = GAMMA_DISCOUNT_FACTOR;
                    double alpha = alpha(Ntable.get(prev_stateaction));

                    double Q = Qtable.get(prev_stateaction)
                            + alpha * (previous_reward + gamma * getMaxActionQValue(new_state) - Qtable.get(prev_stateaction));

                    Qtable.put(prev_stateaction, Q);
                    // done.
                }

                int action = selectAction(new_state); /* Make sure you understand how it selects an action */

                performAction(action);

                /* Only print every 1000th iteration to reduce spam */
                print_counter++;
                if (print_counter % 1000 == 0) {
                    System.out.println("ITERATION: " + iteration + " SENSORS: a=" + df.format(angle.getValue()) + " vx=" + df.format(vx.getValue()) +
                            " vy=" + df.format(vy.getValue()) + " P_STATE: " + previous_state + " P_ACTION: " + previous_action +
                            " P_REWARD: " + df.format(previous_reward) + " P_QVAL: " + df.format(Qtable.get(prev_stateaction)) + " Tested: "
                            + Ntable.get(prev_stateaction) + " times.");
                }

                previous_vy = vy.getValue();
                previous_vx = vx.getValue();
                previous_angle = angle.getValue();
                previous_action = action;
            }
            previous_state = new_state;
        }
    }

    /* Computes the learning rate parameter alpha, based on the number of times the state-action combination has been tested */
    public double alpha(int num_tested) {
        /* A lower learning rate constant means that alpha becomes small faster, which makes the agent's behavior converge to a
         * solution faster; but if the state space has not been properly explored by then, the resulting behavior may be poor.
         * If your state space is really huge, you may need to increase it. */
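        /* For example, with LEARNING_RATE_CONSTANT = 10 this gives
         * alpha(0) = 1.0, alpha(10) = 0.5, alpha(40) = 0.2 and alpha(90) = 0.1. */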
        double alpha = (LEARNING_RATE_CONSTANT / (LEARNING_RATE_CONSTANT + num_tested));
        return alpha;
    }

    /* Finds the highest Q-value of any action in the given state */
    public double getMaxActionQValue(String state) {
        double maxQval = Double.NEGATIVE_INFINITY;

        for (int action = 0; action < NUM_ACTIONS; action++) {
            Double Qval = Qtable.get(state + action);
            if (Qval != null && Qval > maxQval) {
                maxQval = Qval;
            }
        }

        if (maxQval == Double.NEGATIVE_INFINITY) {
            /* Assign 0, as that corresponds to initializing the Q-table to 0 */
            maxQval = 0;
        }
        return maxQval;
    }

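    /* Notes on selectAction() below: state-action pairs that are not yet in
     * the Q-table are treated as having Q = 0, and ties between equally good
     * actions are broken in favor of the lowest action index. */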
    /* Selects an action in a state, based on the registered Q-values and the exploration chance */
    public int selectAction(String state) {
        Random rand = new Random();

        int action = 0;
        /* May do an exploratory move if in exploration mode */
        if (explore && rand.nextDouble() < explore_chance) {
            /* Take a random exploration action! */
            action = rand.nextInt(NUM_ACTIONS);
            return action;
        }

        /* Find the action with the highest Q-value (utility) in the given state */
        double maxQval = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < NUM_ACTIONS; i++) {
            String test_pair = state + i; /* Generate a state-action pair for each action */
            double Qval = 0;
            if (Qtable.get(test_pair) != null) {
                Qval = Qtable.get(test_pair);
            }
            if (Qval > maxQval) {
                maxQval = Qval;
                action = i;
            }
        }
        return action;
    }

    /* The 'E' key toggles the agent's exploration mode. Turn this off to test its behavior */
    public void toggleExplore() {
        /* Make sure we don't toggle it multiple times */
        if (System.currentTimeMillis() - lastPressedExplore < 1000) {
            return;
        }
        if (explore) {
            System.out.println("Turning OFF exploration!");
            explore = false;
        } else {
            System.out.println("Turning ON exploration!");
            explore = true;
        }
        lastPressedExplore = System.currentTimeMillis();
    }

    /* Keys 1 and 2 can be customized for whatever purpose you want */
    public void toggleCustom1() {
        System.out.println("Custom key 1 pressed!");
    }

    /* Keys 1 and 2 can be customized for whatever purpose you want */
    public void toggleCustom2() {
        System.out.println("Custom key 2 pressed!");
    }

    public void pause() {
        paused = true;
        resetRockets();
    }

    public void run() {
        paused = false;
    }
}