def generate_policy(width, height, Q, goal):
    """Print the greedy policy encoded in ``Q`` as a grid of characters.

    One text row is printed per ``i`` in ``range(height)``; each row holds
    one character per ``j`` in ``range(width)``. The goal cell is drawn as
    ``"S"``; every other cell shows the symbol of ``argmax(Q[j, i])``.
    A blank line is printed after the grid.

    Args:
        width: Number of cells along the first axis of ``Q``.
        height: Number of cells along the second axis of ``Q``.
        Q: Array indexed as ``Q[j, i]`` that yields the action values of a
            cell (four actions are supported by the symbol table).
        goal: ``(j, i)`` coordinates of the goal cell. Any sequence is
            accepted; ``None`` means no cell is marked.
    """
    # Symbols for action indices 0..3 (presumably Down/Up/Right/Left; the
    # order must match the environment's action encoding -- confirm there).
    symbols = {0: 'D', 1: 'U', 2: 'R', 3: 'L'}
    # Normalise so a list/array goal still compares equal to the (j, i) tuple.
    goal_cell = None if goal is None else tuple(goal)

    for i in range(height):
        row = "".join(
            "S" if (j, i) == goal_cell else symbols[int(np.argmax(Q[j, i]))]
            for j in range(width)
        )
        print(row)
    print()
import numpy as np
import matplotlib.pyplot as plt

from env import Grid

# Environment and the action-value table, indexed as Q[x, y, action].
grid = Grid()
width = grid.width
height = grid.height
Q = np.zeros((width, height, 4))


def get_action(q_table, eps, x, y):
    """Choose an action index for cell ``(x, y)`` epsilon-greedily.

    Args:
        q_table: Array indexed as ``q_table[x, y]`` giving the action values
            of a cell.
        eps: Probability of picking a uniformly random action instead of
            the greedy one.
        x: First coordinate of the cell.
        y: Second coordinate of the cell.

    Returns:
        The chosen action index.
    """
    action_values = q_table[x, y]
    if np.random.rand() < eps:
        # Explore: the action count comes from the table, not a literal.
        return np.random.randint(len(action_values))
    # Exploit: ties resolve to the lowest index (np.argmax behaviour).
    return np.argmax(action_values)
# Training hyper-parameters.
EPOCH = 10000    # number of episodes
MIN_EPS = 0.05   # exploration rate approached at the end of training
GAMMA = 0.95     # discount factor
LR = 0.1         # learning rate

# Episode returns sampled every EPOCH/200 episodes (used for plotting).
rets = []
for i in range(EPOCH):
    # Linear decay of epsilon from 1.0 (i == 0) towards MIN_EPS.
    eps = 1 - (1 - MIN_EPS) / EPOCH * i
    pos = grid.reset()
    a = get_action(Q, eps, *pos)
    ret = 0
    while True:
        r, pos_, flag = grid.step(a)
        ret += r
        # On-policy (SARSA) update: bootstrap from the action that will
        # actually be taken next, chosen with the same epsilon-greedy rule.
        a_ = get_action(Q, eps, *pos_)
        td_target = r + GAMMA * Q[pos_[0], pos_[1], a_]
        td_error = td_target - Q[pos[0], pos[1], a]
        Q[pos[0], pos[1], a] += LR * td_error
        pos = pos_
        a = a_
        if flag:
            # grid.step's third return value ends the episode.
            break
    if i % (EPOCH / 200) == 0:
        rets.append(ret)
def generate_policy(width, height, Q, goal):
    """Print the greedy policy encoded in ``Q`` as a grid of characters.

    One text row is printed per ``i`` in ``range(height)``; each row holds
    one character per ``j`` in ``range(width)``. The goal cell is drawn as
    ``"S"``; every other cell shows the symbol of ``argmax(Q[j, i])``.
    A blank line is printed after the grid.

    Args:
        width: Number of cells along the first axis of ``Q``.
        height: Number of cells along the second axis of ``Q``.
        Q: Array indexed as ``Q[j, i]`` that yields the action values of a
            cell (four actions are supported by the symbol table).
        goal: ``(j, i)`` coordinates of the goal cell. Any sequence is
            accepted; ``None`` means no cell is marked.
    """
    # Symbols for action indices 0..3 (presumably Down/Up/Right/Left; the
    # order must match the environment's action encoding -- confirm there).
    symbols = {0: 'D', 1: 'U', 2: 'R', 3: 'L'}
    # Normalise so a list/array goal still compares equal to the (j, i) tuple.
    goal_cell = None if goal is None else tuple(goal)

    for i in range(height):
        row = "".join(
            "S" if (j, i) == goal_cell else symbols[int(np.argmax(Q[j, i]))]
            for j in range(width)
        )
        print(row)
    print()
# Show the learned greedy policy as a character grid.
goal = grid.goal
generate_policy(width, height, Q, goal)

# Plot the sampled episode returns against the episode index.
fig = plt.figure()
# Returns are sampled every EPOCH/200 episodes. Building x from len(rets)
# keeps both sequences the same length; a float-step np.arange can come out
# one element longer or shorter than expected.
x = np.arange(len(rets)) * (EPOCH / 200)
plt.plot(x, rets)
plt.show()