# 无痛的增强学习入门： 策略迭代

| 作者：冯超 · 发布于 2017年8月8日 · 估计阅读时间：16 分钟

## 3 策略迭代法

### 3.1 策略迭代法

def simple_eval():
    """Evaluate three hand-crafted policies on the Snake environment.

    Prints the average episode return of:
      * the all-zeros policy (agent's initial policy),
      * the all-ones policy,
      * an ensemble: action 1 everywhere except states 97-99.

    NOTE(review): ``eval`` here is the article's episode evaluator
    (defined elsewhere in the source); it shadows the builtin `eval`.
    """
    env = Snake(0, [3, 6])
    agent = TableAgent(env.state_transition_table(), env.reward_table())
    # Default policy (all zeros) as constructed by TableAgent.
    print('return3={}'.format(eval(env, agent)))
    # Switch every state to action 1.
    agent.policy[:] = 1
    print('return6={}'.format(eval(env, agent)))
    # Mixed policy: action 0 only on the last three states (97..99).
    agent.policy[97:100] = 0
    print('return_ensemble={}'.format(eval(env, agent)))

return3=49
return6=68
return_ensemble=70

def policy_evaluation(self):
    """Iteratively evaluate the current policy (in-place on ``self.value_pi``).

    Repeats the Bellman backup for the policy's chosen action,

        V(s) <- sum_{s'} P(s' | s, pi(s)) * (R(s') + gamma * V(s')),

    sweeping states 1..state_num-1 (state 0 is deliberately skipped,
    matching the rest of the agent's loops), until the L2 norm of the
    change between sweeps drops below 1e-6.

    Improvements over the original: the original computed the backup for
    EVERY action and discarded all but ``policy[i]``'s, and recomputed
    the loop-invariant ``reward + gamma * value_pi`` once per (state,
    action) pair.  Both are hoisted here; the result is identical since
    ``self.value_pi`` is only read (not written) during a sweep.
    """
    while True:
        new_value_pi = self.value_pi.copy()
        # Loop-invariant backup target: R(s') + gamma * V(s').
        backup = self.reward + self.gamma * self.value_pi
        for i in range(1, self.state_num):  # for each state (state 0 skipped)
            act = self.policy[i]
            # Expected backup under the policy's action only.
            new_value_pi[i] = np.dot(self.table[act, i, :], backup)
        diff = np.sqrt(np.sum(np.power(self.value_pi - new_value_pi, 2)))
        if diff < 1e-6:
            break
        else:
            self.value_pi = new_value_pi

def policy_improvement(self):
    """One greedy policy-improvement step.

    Recomputes the action-value table
        Q(s, a) = sum_{s'} P(s' | s, a) * (R(s') + gamma * V(s'))
    for states 1..state_num-1 (state 0 skipped, as in evaluation) and
    makes the policy greedy with respect to Q.

    Returns:
        bool: True if the policy changed (iterate again),
              False if it was already greedy (converged).

    Improvements over the original: ``reward + gamma * value_pi`` is
    hoisted out of the loop, the per-action inner loop is replaced by a
    single matrix product (``table[:, i, :]`` has shape (act_num,
    state_num)), and the comparison uses ``np.array_equal``.  The
    computed Q values and the argmax are identical.
    """
    new_policy = np.zeros_like(self.policy)
    # Loop-invariant backup target: R(s') + gamma * V(s').
    backup = self.reward + self.gamma * self.value_pi
    for i in range(1, self.state_num):
        # All actions' Q values for state i in one product.
        self.value_q[i, :] = np.dot(self.table[:, i, :], backup)
        # Greedy action for state i.
        new_policy[i] = np.argmax(self.value_q[i, :])
    if np.array_equal(new_policy, self.policy):
        return False
    self.policy = new_policy
    return True

def policy_iteration(self):
    """Alternate evaluation and improvement until the policy is stable.

    Each round runs a full policy evaluation followed by one greedy
    improvement step; the loop stops when ``policy_improvement`` reports
    no change, then the number of rounds is printed.

    (Restored the extraction-stripped indentation and converted the
    Python 2 ``print`` statement to a ``print()`` call.)
    """
    iteration = 0
    while True:
        iteration += 1
        self.policy_evaluation()
        ret = self.policy_improvement()
        if not ret:
            break
    print('Iter {} rounds converge'.format(iteration))

def policy_iteration_demo():
    """Run policy iteration on the Snake(0, [3, 6]) environment.

    Prints the evaluated return of the converged policy and the policy
    itself.  NOTE(review): ``eval`` is the article's episode evaluator
    (shadows the builtin), defined elsewhere in the source.
    """
    env = Snake(0, [3, 6])
    agent = TableAgent(env.state_transition_table(), env.reward_table())
    agent.policy_iteration()
    print('return_pi={}'.format(eval(env, agent)))
    print(agent.policy)

Iter 2 rounds converge
return_pi=70
[0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0]

### 3.3 策略迭代的展示

def policy_iteration_demo():
    """Compare hand-crafted policies against policy iteration on Snake(10, [3, 6]).

    Prints the returns of the all-zeros policy, the all-ones policy, a
    mixed policy (action 0 on states 97-99), and finally the policy
    found by policy iteration.  NOTE(review): ``eval`` is the article's
    episode evaluator (shadows the builtin), defined elsewhere.
    """
    env = Snake(10, [3, 6])
    agent = TableAgent(env.state_transition_table(), env.reward_table())
    # Default (all-zeros) policy.
    print('return3={}'.format(eval(env, agent)))
    # All-ones policy.
    agent.policy[:] = 1
    print('return6={}'.format(eval(env, agent)))
    # Mixed policy: action 0 only on the last three states.
    agent.policy[97:100] = 0
    print('return_ensemble={}'.format(eval(env, agent)))
    # Optimal policy via policy iteration.
    agent.policy_iteration()
    print('return_pi={}'.format(eval(env, agent)))
    print(agent.policy)

return3=-45
return6=21
return_ensemble=31
return_pi=41
[0 1 0 0 0 1 1 1 1 1 0 0 0 0 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 0 1 0 1 1 1 1 1
1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 0 0 0 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 1 1 1 0 0 0]