# Compute the id of the next state given the current state s and the action a taken.
def nextState(s, a):
    next_state = s
    # Moves that would leave the 4x4 grid keep the agent in place.
    if (s % 4 == 0 and a == "w") or (s < 4 and a == "n") or \
       ((s + 1) % 4 == 0 and a == "e") or (s > 11 and a == "s"):
        pass
    else:
        ds = ds_actions[a]
        next_state = s + ds
    return next_state
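nextState relies on a ds_actions mapping that is not shown above. Judging from the boundary checks, the grid is 4x4 with states 0-15 numbered row by row, so a plausible definition (an assumption, not taken from the source) and a quick usage check look like this:

# Assumed action-to-offset mapping for a 4x4 grid, states 0..15 numbered row-major.
ds_actions = {"n": -4, "e": 1, "s": 4, "w": -1}

# Example: from state 5, moving east reaches state 6; moving north from state 1 is blocked.
print(nextState(5, "e"))   # 6
print(nextState(1, "n"))   # 1 (top row, agent stays in place)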
Update rule
# Update the value of state s under the uniform random policy.
def updateValue(s):
    successors = getSuccessors(s)
    newValue = 0      # accumulate a fresh estimate instead of reusing values[s]
    num = 4           # len(successors); each of the four actions has probability 1/4
    reward = rewardOf(s)
    for next_state in successors:
        newValue += 1.00 / num * (reward + gamma * values[next_state])
    return newValue
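updateValue performs one backup for a single state; to evaluate the random policy it must be applied repeatedly over the whole grid. Below is a minimal synchronous sweep, assuming the same globals values, gamma, getSuccessors and rewardOf used above; handling of terminal states (which the source does not identify) is omitted.

# Repeatedly sweep all 16 states until the value function stops changing (synchronous updates).
def evaluateAll(theta=1e-6, max_iter=1000):
    global values
    for _ in range(max_iter):
        new_values = [updateValue(s) for s in range(16)]
        delta = max(abs(new_values[s] - values[s]) for s in range(16))
        values = new_values
        if delta < theta:
            break
    return values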
def policy_evaluate(self, grid_mdp):
    # Iterative policy evaluation: sweep the states until the values converge.
    for i in range(1000):
        delta = 0.0
        for state in grid_mdp.states:
            if state in grid_mdp.terminal_states: continue
            action = self.pi[state]
            t, s, r = grid_mdp.transform(state, action)
            new_v = r + grid_mdp.gamma * self.v[s]
            delta += abs(self.v[state] - new_v)
            self.v[state] = new_v
        if delta < 1e-6:
            break
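policy_evaluate depends on grid_mdp.transform(state, action), which is not defined in this section. The sketch below only illustrates the assumed interface, returning (is_terminal, next_state, reward); the state ids, terminal states, gamma and reward values are placeholders, not taken from the source.

class GridMDP:
    def __init__(self):
        self.gamma = 0.8                      # discount factor (assumed value)
        self.states = list(range(1, 9))       # state ids (assumed layout)
        self.terminal_states = {6, 7, 8}      # absorbing states (assumed)
        self.actions = ["n", "e", "s", "w"]
        self.t = {}                           # (state, action) -> next state
        self.r = {}                           # (state, action) -> immediate reward

    def transform(self, state, action):
        # Return (is_terminal, next_state, reward) for taking `action` in `state`.
        next_state = self.t.get((state, action), state)   # undefined moves keep the agent in place
        reward = self.r.get((state, action), 0.0)
        return next_state in self.terminal_states, next_state, reward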
Policy improvement method
def policy_improve(self, grid_mdp):
    # Greedy policy improvement: in each state keep the action with the largest one-step backup.
    for state in grid_mdp.states:
        if state in grid_mdp.terminal_states: continue
        a1 = grid_mdp.actions[0]
        t, s, r = grid_mdp.transform(state, a1)
        v1 = r + grid_mdp.gamma * self.v[s]
        for action in grid_mdp.actions:
            t, s, r = grid_mdp.transform(state, action)
            if v1 < r + grid_mdp.gamma * self.v[s]:
                a1 = action
                v1 = r + grid_mdp.gamma * self.v[s]
        self.pi[state] = a1
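Policy iteration alternates the two steps above until the policy stops changing. A sketch of that outer loop, assuming self.pi can be copied with copy.deepcopy (the method name policy_iteration and the iteration cap are ours):

import copy

def policy_iteration(self, grid_mdp):
    # Alternate evaluation and greedy improvement until the policy is stable.
    for _ in range(100):
        self.policy_evaluate(grid_mdp)     # make self.v consistent with the current self.pi
        old_pi = copy.deepcopy(self.pi)
        self.policy_improve(grid_mdp)      # greedify self.pi with respect to self.v
        if self.pi == old_pi:
            break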
Value iteration algorithm
def value_iteration(self, grid_mdp):
    # Value iteration: back up each state with the best one-step return and record the greedy action.
    for i in range(1000):
        delta = 0.0
        for state in grid_mdp.states:
            if state in grid_mdp.terminal_states: continue
            a1 = grid_mdp.actions[0]
            t, s, r = grid_mdp.transform(state, a1)
            v1 = r + grid_mdp.gamma * self.v[s]
            for action in grid_mdp.actions:
                t, s, r = grid_mdp.transform(state, action)
                if v1 < r + grid_mdp.gamma * self.v[s]:
                    a1 = action
                    v1 = r + grid_mdp.gamma * self.v[s]
            delta += abs(self.v[state] - v1)
            self.v[state] = v1
            self.pi[state] = a1
        if delta < 1e-6:
            break
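Once value_iteration (or policy iteration) has filled self.pi, the greedy policy can be rolled out through grid_mdp.transform. A small usage sketch; the method name run_greedy and the max_steps cap are ours:

def run_greedy(self, grid_mdp, state, max_steps=20):
    # Follow the policy stored in self.pi from `state` until a terminal state is reached.
    trajectory = [state]
    for _ in range(max_steps):
        if state in grid_mdp.terminal_states:
            break
        action = self.pi[state]
        t, state, r = grid_mdp.transform(state, action)
        trajectory.append(state)
    return trajectory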