[강화학습] DQN 알고리즘을 구현해보자.

2023. 12. 14. 21:02카테고리 없음

DQN 알고리즘은 Q-learning을 심층신경망을 이용해 근사하는 방식입니다. 하지만, 심층신경망만을 이용한다고해서 성능이 좋지는 않아서 target network와 experience replay를 사용해 뛰어난 성능을 높였습니다.

 

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha [R_{t+1} + \gamma \max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)]
$$

target network와 experience replay를 사용하지않은 DQN을 naive DQN이라고합니다.

두 기술을 사용하지 않았을 시에는 target이 계속 움직여 학습이 불안정하다는 문제와 temporal correlation의 문제가 있었습니다. 계속 움직여서 생기는 문제는 target network를 이용해 움직이지 않도록 함으로써 해결하였고, temporal correlation은 데이터를 많이 모아 무작위로 뽑는 experience replay를 통해 해결하였다.

 

사실 이러한 간단한 알고리즘은 로봇에는 적용하기 힘들 것 같아 구현도 하지 않고 개념만 공부하고 넘어갔었다. 그런데 왠걸! 강화학습으로 로봇을 제어하는 논문을 읽어보았는데, DQN을 변형해서 학습시키는 것을 알게되었다. 그리고, DDPG를 구현할 때 DQN과 상당부분 유사하기때문에 구현해보면 좋다고 생각한다.(혹시 그냥 넘어갈려는 분들에게 드리는 팁..)

 

알고리즘에 관한 설명은 많으니 짧게만 설명하고 바로 구현으로 넘어가겠습니다.

구현 코드는 공식 Pytorch 문서의 코드를 참고하였습니다.

https://tutorials.pytorch.kr/intermediate/reinforcement_q_learning.html

(굉장히 많은 것들을 배울 수 있었던 코드. torch.no_grad()부터, tuple unpacking, namedtuple등등. 하지만 DDPG를 구현할 때 동일하게 해보았는데 Replaybuffer에서 샘플을 뽑는 것이 너무 느려서(약 700배 차이였던거 같다) TD3를 만든 사람의 코드를 보고 DDPG코드를 바꿨습니다. DDPG글도 많이 봐주세요~)

 

ReplayBuffer는 다음과 같이 간단하게 구현하였습니다.

class ReplayBuffer:
    def __init__(self, max_length):
        self.buffer = collections.deque(maxlen=max_length)
        
    def put_data(self,*args):
        self.buffer.append(Transition(*args))
        
    def sample_minibatch(self):
        return random.sample(self.buffer, BATCH_SIZE)
        
    def __len__(self):
        return len(self.buffer)

 

 

네트워크는 다음과 같이 구성하였습니다.

class Qnetwork(nn.Module):
    def __init__(self, obs_space_dims:int, action_space_dims:int):
        super(Qnetwork,self).__init__()
        
        hidden_layer1 = 128
        hidden_layer2 = 128
        
        self.net = nn.Sequential(
            nn.Linear(obs_space_dims, hidden_layer1),
            nn.ReLU(),
            nn.Linear(hidden_layer1, hidden_layer2),
            nn.ReLU(),
            nn.Linear(hidden_layer2, action_space_dims)
        )
    
    def forward(self, state:torch.Tensor):
        q_value = self.net(state)
        return q_value

 

 

DQN은 에이전트 클래스로 다음과 같이 구성하였습니다.

sample_action함수에서는 random.random()함수로 0~1사이의 값을 얻습니다. 이때, threshold보다 작으면 아무 액션이나 선택하고(exploration), 크다면 state를 네트워크에 통과시켜 Q값을 얻은 후에 가장 큰 값을 고릅니다.

torch.tensor에 max함수를 사용하면 최대값과 최대값의 인덱스를 반환합니다. max함수에 주어진 인수는 차원을 뜻합니다.

class DQN:
    def __init__(self, obs_space_dims:int, action_space_dims:int):

        self.action_space_dims = action_space_dims
        
        self.qNet = Qnetwork(obs_space_dims, action_space_dims)
        self.qtargetNet = Qnetwork(obs_space_dims, action_space_dims)
        self.qtargetNet.load_state_dict(self.qNet.state_dict()) # copy
        self.optimizer = optim.AdamW(self.qNet.parameters(), lr=LEARNING_RATE)
        
        self.buffer = ReplayBuffer(BUFFER_SIZE)

    def sample_action(self, state:torch.Tensor, threshold:float):
        coin = random.random()
        if coin > threshold:
            with torch.no_grad():
                return self.qNet(state).max(1)[1].view(1,1)
        else:
            sample = random.sample([0,1], 1)
            return torch.tensor([sample],dtype=torch.long)
            
    def update(self):
        if len(self.buffer) < BATCH_SIZE:
            return
        transitions = self.buffer.sample_minibatch() # transition은 tuple로 이루어진 list
        batch = Transition(*zip(*transitions))
        
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                 batch.next_state)), dtype=torch.bool)
        non_final_next_state = torch.cat([s for s in batch.next_state
                                             if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        
        q_values = self.qNet(state_batch).gather(1, action_batch)
        
        next_q_values = torch.zeros(BATCH_SIZE)
        with torch.no_grad():
            next_q_values[non_final_mask] = self.qtargetNet(non_final_next_state).max(1)[0]
            
        target = reward_batch + GAMMA * next_q_values
        loss = F.smooth_l1_loss(target.unsqueeze(1), q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 목표 네트워크의 가중치를 소프트 업데이트
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = self.qtargetNet.state_dict()
        policy_net_state_dict = self.qNet.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        self.qtargetNet.load_state_dict(target_net_state_dict)

 

update함수에서는 충분한 데이터가 모이기 전까지 학습을 하지 않습니다. 이는 적은 데이터로 성급하게 일반화하는 것을 방지합니다.

이 코드에서 transitions는 namedtuple로 이루어진 리스트입니다. 이를 튜플 언패킹(tuple unpacking)을 통해 zip함수에 전달합니다. 그러면 zip 함수는 (s1,a1,r2,s2), (s2,a2,r3,s3), (s3,a3,r4,s4)로 되어있던 데이터를 (s1,s2,s3), (a1,a2,a3), (r2,r3,r4), (s2, s3, s4)로 만들어줍니다. zip함수는 list를 반환하는데 이를 namedtuple인 Transition에 넣기위해 다시 언패킹해서 전달합니다. 

batch = Transition(*zip(*transitions))

 

main함수에서 terminated가 되었을 때는 next_state를 None으로 처리했습니다. 이 경우 next_state의 Q값을 0으로 해주어야하는데 mask를 만들어 처리하면 쉽게 처리할 수 있습니다.

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                 batch.next_state)), dtype=torch.bool)
        non_final_next_state = torch.cat([s for s in batch.next_state
                                             if s is not None])

 

DQN에서는 C스텝마다 업데이트를 해주었지만, 소프트 업데이트가 더 좋은 거 같아 소프트 업데이트를 적용했습니다.

        # 목표 네트워크의 가중치를 소프트 업데이트
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = self.qtargetNet.state_dict()
        policy_net_state_dict = self.qNet.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        self.qtargetNet.load_state_dict(target_net_state_dict)

 

위 코드를 이해하셨다면 main함수는 별 거 없습니다. 제 코드에서는 Linear anneling을 통해 epsilon값을 조정해주었고, 참고코드에서는 비선형함수로 조정해주었습니다.

def main():
    
    #env = gym.make('CartPole-v1', render_mode = 'human') # 학습하는 것을 보고 싶으시다면
    env = gym.make('CartPole-v1')
    observation_space_dims = env.observation_space.shape[0]
    action_space_dims = env.action_space.n
    
    for seed in [2,3,5]:
        
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        
        agent = DQN(observation_space_dims, action_space_dims)
        
        # 출력변수
        print_interval = 50
        score = 0.0
        
        for n_episode in range(EPISODE):
            state, _ = env.reset(seed=seed)
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0) # size [1,4]
            epsilon = max(0.01, 0.08 - 0.01*(n_episode/200)) #Linear annealing from 8% to 1% (n_episod = 1400 -> 1%)
            done = False
            
            while not done:
                action = agent.sample_action(state, epsilon) # size[1,1]
                observation, reward, terminated, truncated, _ = env.step(action.item())
                done = terminated or truncated
                reward = torch.tensor([reward]) # size [1]
                
                if terminated:
                    next_state = None
                else:
                    next_state = torch.tensor(observation, dtype=torch.float32).unsqueeze(0) # size [1,4]
                
                agent.buffer.put_data(state,action,reward,next_state)
                agent.update()
                
                state = next_state
                score += reward.item()
                
                if done:
                    break
                
            
            if n_episode % print_interval == 0 and n_episode != 0:
                print("seed : %d, episode = %d, avg_score : %.2f"
                      %(seed, n_episode, score/print_interval))
                score = 0

if __name__ == '__main__':
    main()