path: root/2017/ev3-controller/line_follower_Q.py
# Copyright 2017 Amra Omanović, Nejka Bolčič, Magda Nowak-Trzos
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

import numpy as np
from EV3Robot import *
import csv
gamma = 0.8  # discount factor for future rewards
alpha = 1.   # learning rate

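# Connect to the EV3 brick: two drive motors and the colour sensor.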
robot = Robot()
robot.connect_motor( 'left' )
robot.connect_motor( 'right' )
robot.connect_sensor( 'color' )

# Reward matrix: rows are the three states returned by bgw()
# (0 = black, 1 = white, 2 = grey zone), columns are the three
# actions handled by do_action().
r = np.array([[1, -10, -1],
              [-100, 10, -1],
              [-100, -10, 100]]).astype("float32")

# Q-table, one row per state and one column per action,
# initialised with random values in [0, 1).
q = np.random.rand(3, 3)

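# Standard Q-learning update:
#   Q(s, a) <- Q(s, a) + alpha * (r(s, a) + gamma * max_a' Q(s', a') - Q(s, a))
# Note: the extra renormalisation of the row's positive entries below is
# specific to this implementation, not part of textbook Q-learning.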
def update_q(state, next_state, action):
    r_sa = r[state, action]  # reward for this state/action pair
    q_sa = q[state, action]  # current Q-value for this state/action pair
    new_q = q_sa + alpha * (r_sa + gamma * max(q[next_state, :]) - q_sa)
    q[state, action] = new_q
    # renormalize the positive entries of the row so they sum to 1
    rn = q[state][q[state] > 0] / np.sum(q[state][q[state] > 0])
    q[state][q[state] > 0] = rn
    return r[state, action]

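# Discretise the reflected light reading into three states:
# 0 = black (below follow_color - grey_zone), 1 = white (above
# follow_color + grey_zone), 2 = the grey zone in between.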
def bgw(isee, follow_color=50, grey_zone=25):
    if isee < follow_color - grey_zone: # BLACK
        return 0
    elif isee > follow_color + grey_zone: # WHITE
        return 1
    else: # move forward if in the grey zone
        return 2

def get_state():
    isee = robot.color_sensor_measure('reflected_light_intensity')
    color = bgw(isee)
    return color

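# Execute a discrete action via robot.move(), whose two arguments are
# presumably the left and right motor speeds: actions 0 and 1 drive only
# one motor (turning), action 2 drives both (straight ahead).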
def do_action(action, speed):
    if action == 0:
        robot.move(0,speed)
    elif action == 1:
        robot.move(speed, 0)
    elif action == 2:
        robot.move(speed, speed)


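# Main training loop: repeatedly observe the state, act greedily on the
# current Q-table, observe the resulting state, apply the Q-learning
# update, and log transitions to training_data.csv.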
def run(speed):
    # open in text mode so csv.writer can write strings
    data_file = open('training_data.csv', 'w')
    while True:
        # observe the current state
        state_1 = get_state()
        # pick the action with the highest Q-value for this state (greedy)
        action = np.argmax(q[state_1, :])
        # do the action
        do_action(action, speed)
        state_2 = get_state()
        update_q(state_1, state_2, action)

        # act greedily again from the new state and update once more
        action2 = np.argmax(q[state_2, :])
        do_action(action2, speed)
        state_3 = get_state()
        update_q(state_2, state_3, action2)

        # log the observed transition (previous state, new state, action taken)
        collect_data(data_file, state_1, state_2, action)


def collect_data(data_file, prev_state, current_state, action):
    writer = csv.writer(data_file, delimiter=',')
    writer.writerow([prev_state, current_state, action])



run(15)  # start the Q-learning line follower with speed 15