summaryrefslogtreecommitdiff
path: root/2017/ev3-controller/line_follower_Q.py
diff options
context:
space:
mode:
Diffstat (limited to '2017/ev3-controller/line_follower_Q.py')
-rw-r--r--2017/ev3-controller/line_follower_Q.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/2017/ev3-controller/line_follower_Q.py b/2017/ev3-controller/line_follower_Q.py
new file mode 100644
index 0000000..a1ea427
--- /dev/null
+++ b/2017/ev3-controller/line_follower_Q.py
@@ -0,0 +1,80 @@
+import numpy as np
+from EV3Robot import *
+import csv
# Q-learning hyperparameters.
gamma = 0.8  # discount factor for future reward
alpha = 1.   # learning rate (1.0 = fully replace the old Q estimate)

# Connect to the EV3 brick: two drive motors and the downward-facing
# color sensor used for line following.
robot = Robot()
robot.connect_motor( 'left' )
robot.connect_motor( 'right' )
robot.connect_sensor( 'color' )

# Reward matrix r[state, action]. Rows are the discretized sensor states
# (0=black, 1=white, 2=grey — see bgw()); columns are the actions
# (0=turn one way, 1=turn the other, 2=straight — see do_action()).
# NOTE(review): reward magnitudes look hand-tuned — confirm intent.
r = np.array([[1, -10, -1],
              [-100, 10, -1],
              [-100, -10, 100]]).astype("float32")

#squre the difference

# Q table (3 states x 3 actions), randomly initialized in [0, 1);
# updated in place by update_q().
q = np.random.rand(3,3)
+
def update_q(state, next_state, action):
    """Apply one Q-learning update for (state, action) -> next_state.

    Updates the module-level table ``q`` in place using the standard
    temporal-difference rule, then rescales the positive entries of the
    updated row so they sum to 1. Returns the immediate reward
    ``r[state, action]``.
    """
    reward = r[state, action]
    old_value = q[state, action]
    # Standard TD update: Q += alpha * (reward + gamma * max_a' Q(s', a') - Q)
    best_next = q[next_state, :].max()
    q[state, action] = old_value + alpha * (reward + gamma * best_next - old_value)
    # Rescale the positive entries of this row to sum to 1.
    # q[state] is a view, so assigning through the mask mutates q itself.
    row = q[state]
    positive = row > 0
    row[positive] = row[positive] / row[positive].sum()
    return reward
+
def bgw(isee, follow_color=50, grey_zone=25):
    """Discretize a reflected-light reading into a state.

    Returns 0 (BLACK) below ``follow_color - grey_zone``, 1 (WHITE)
    above ``follow_color + grey_zone``, and 2 for readings inside the
    grey zone (robot should keep moving forward).
    """
    lower = follow_color - grey_zone
    upper = follow_color + grey_zone
    if isee < lower:
        return 0  # BLACK
    if isee > upper:
        return 1  # WHITE
    return 2  # grey zone
+
def get_state():
    """Read the color sensor and return the discretized state (0/1/2)."""
    reading = robot.color_sensor_measure('reflected_light_intensity')
    return bgw(reading)
+
def do_action(action, speed):
    """Drive the motors for one discrete action.

    Action 0 and 1 pivot by powering only one wheel; action 2 drives
    both wheels straight ahead. Any other action value is ignored,
    matching the original if/elif chain.
    """
    wheel_speeds = {
        0: (0, speed),       # pivot: right wheel only
        1: (speed, 0),       # pivot: left wheel only
        2: (speed, speed),   # straight ahead
    }
    if action in wheel_speeds:
        left, right = wheel_speeds[action]
        robot.move(left, right)
+
+
def run(speed):
    """Main control loop: follow the line forever with Q-learning.

    Each iteration performs two sense/act/update steps, then appends one
    (prev_state, current_state, action) row to training_data.csv.

    NOTE(review): the loop never terminates and the file handle is never
    closed; the 'wb' binary mode with csv.writer is the Python 2 idiom —
    on Python 3 this should be open('training_data.csv', 'w') with
    newline=''. Confirm the target interpreter before changing it.
    """
    data_file = open('training_data.csv', 'wb')
    while 1:
        # Observe the current state.
        state_1 = get_state()
        # BUG FIX: pick the greedy action for the *current state*, i.e.
        # argmax over the action axis of row q[state_1, :].
        # The original np.argmax(q, axis=0)[state_1] argmaxed over the
        # state axis for each action column and then indexed that
        # per-action result with a state — selecting the wrong action.
        action = np.argmax(q[state_1, :])
        do_action(action, speed)
        state_2 = get_state()
        update_q(state_1, state_2, action)

        # Second sense/act/update step of the iteration.
        action2 = np.argmax(q[state_2, :])
        do_action(action2, speed)
        state_3 = get_state()
        update_q(state_2, state_3, action2)

        collect_data(data_file, state_1, state_2, action2)
+
+
def collect_data(data_file, prev_state, current_state, action):
    """Append one (prev_state, current_state, action) row to the open CSV file."""
    csv.writer(data_file, delimiter=',').writerow(
        [prev_state, current_state, action])
+
+
+
# Guard the entry point so importing this module (e.g. for testing the
# helpers) does not immediately start driving the robot.
if __name__ == '__main__':
    run(15)
+
+