缣移双 发表于 2025-6-1 21:41:15

DQN算法

在Q-learning的学习过程中,我们需要维护一个 |S|x|A| 的Q表,当任务的状态空间和动作空间过大时,空间复杂度和时间复杂度都太高,为了解决这个问题,DQN采用神经网络来代替Q表,输入状态,预估该状态下采用不同动作的Q值
神经网络本身不是DQN的精髓,神经网络可以设计成MLP也可以设计成CNN等等,DQN的巧妙之处在于两个网络、经验回放等trick
 
Trick 1:两个网络

DQN算法采用了2个神经网络,分别是evaluate network(Q值网络)和target network(目标网络),两个网络结构完全相同:

[*]evaluate network用用来计算策略选择的Q值和Q值迭代更新,梯度下降、反向传播的也是evaluate network
[*]target network用来计算TD Target中下一状态的Q值,网络参数更新来自evaluate network网络参数复制
 设计target network目的是为了保持目标值稳定,防止过拟合,从而提高训练过程稳定和收敛速度
 
Trick 2:经验回放Experience Replay

DQN算法设计了一个固定大小的记忆库memory,用来记录经验,经验是一条一条的observation或者说是transition,它表示成  ,含义是当前状态→当前状态采取的动作→获得的奖励→转移到下一个状态
一开始记忆库memory中没有经验,也没有训练evaluate network,积累了一定数量的经验之后,再开始训练evaluate network。记忆库memory中的经验可以是自己历史的经验(epsilon-greedy得到的经验),也可以学习其他人的经验。训练evaluate network的时候,是从记忆库memory中随机选择batch size大小的经验,喂给evaluate network
设计记忆库memory并且随机选择经验喂给evaluate network的技巧打破了相邻训练样本之间相关性,试着想下,状态→动作→奖励→下一个状态的循环是具有关联的,用相邻的样本连续训练evaluate network会带来网络过拟合泛化能力差的问题,而经验回放技巧增强了训练样本之间的独立性
 
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

# 定义DQN网络
class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
      super(DQN, self).__init__()
      self.fc1 = nn.Linear(input_dim, 64)
      self.fc2 = nn.Linear(64, output_dim)

    def forward(self, x):
      x = torch.relu(self.fc1(x))
      x = self.fc2(x)
      return x

# 定义DQN智能体
class DQNAgent:
    def __init__(self, state_dim, action_dim):
      self.state_dim = state_dim
      self.action_dim = action_dim
      self.gamma = 0.99# 折扣因子
      self.epsilon = 1.0# 探索率
      self.epsilon_min = 0.01
      self.epsilon_decay = 0.995
      self.learning_rate = 0.001
      self.memory = deque(maxlen=2000)
      self.model = DQN(state_dim, action_dim)
      self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
      self.criterion = nn.MSELoss()

    def remember(self, state, action, reward, next_state, done):
      self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
      if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_dim)
      state = torch.FloatTensor(state).unsqueeze(0)
      q_values = self.model(state)
      action = torch.argmax(q_values, dim=1).item()
      return action

    def replay(self, batch_size):
      if len(self.memory) < batch_size:
            return
      minibatch = random.sample(self.memory, batch_size)
      for state, action, reward, next_state, done in minibatch:
            state = torch.FloatTensor(state).unsqueeze(0)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)
            target = reward
            if not done:
                target = (reward + self.gamma * torch.max(self.model(next_state)).item())
            target_f = self.model(state)
            target_f = target
            self.optimizer.zero_grad()
            output = self.model(state)
            loss = self.criterion(output, target_f)
            loss.backward()
            self.optimizer.step()
      if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

# 训练函数
def train_dqn(agent, env, episodes=500, batch_size=32):
    for episode in range(episodes):
      state = env.reset()
      if isinstance(state, tuple):
            state = state
      state = np.eye(env.observation_space.n)
      total_reward = 0
      done = False
      while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            next_state = np.eye(env.observation_space.n)
            agent.remember(state, action, reward, next_state, done)
            agent.replay(batch_size)
            state = next_state
            total_reward += reward
      print(f"Episode {episode + 1}: Total Reward = {total_reward}")

# 主函数
if __name__ == "__main__":
    env = gym.make('CliffWalking-v0')
    state_dim = env.observation_space.n
    action_dim = env.action_space.n
    agent = DQNAgent(state_dim, action_dim)
    train_dqn(agent, env)
    env.close()DQN 
参考资料

DQN基本概念和算法流程(附Pytorch代码)
 

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
页: [1]
查看完整版本: DQN算法