MoG-DQN Implementation In PyTorch To Solve Lunar Lander

Distributional Deep Reinforcement Learning with a mixture of Gaussians (MoG-DQN) improves performance and stability of deep reinforcement learning by using a distributional perspective of the value function.

The value distribution is modeled as a mixture of Gaussians instead of a single point estimate or a simple categorical distribution.

Gaussian Mixture Model

+--------------------------------------------------+
|                GaussianMixtureModel              |
+--------------------------------------------------+
| - state_dim: int                                 |
| - action_dim: int                                |
| - num_components: int                            |
| - fc1: nn.Linear                                 |
| - fc2: nn.Linear                                 |
| - mean: nn.Linear                                |
| - log_var: nn.Linear                             |
| - logits: nn.Linear                              |
+--------------------------------------------------+
| + __init__(self, state_dim: int, action_dim: int,|
|            num_components: int, hidden_dim=256)  |
| + forward(self, state: torch.Tensor) ->          |
|        Tuple[torch.Tensor, torch.Tensor,         |
|              torch.Tensor]                       |
| + get_distribution(self, state: torch.Tensor) -> |
|        Tuple[torch.Tensor, torch.Tensor,         |
|              torch.Tensor]                       |
+--------------------------------------------------+

Attributes:

  • state_dim: int – Dimension of the input state.
  • action_dim: int – Dimension of the action space.
  • num_components: int – Number of components in the Gaussian mixture model.
  • fc1: nn.Linear – First fully connected layer with input size state_dim and output size hidden_dim (256 by default).
  • fc2: nn.Linear – Second fully connected layer with input size hidden_dim and output size hidden_dim.
  • mean: nn.Linear – Fully connected layer that outputs the means of the Gaussian components with input size hidden_dim and output size action_dim * num_components.
  • log_var: nn.Linear – Fully connected layer that outputs the log variances of the Gaussian components with input size hidden_dim and output size action_dim * num_components.
  • logits: nn.Linear – Fully connected layer that outputs the logits for the mixture weights with input size hidden_dim and output size action_dim * num_components.
  • Methods:
    • __init__(self, state_dim: int, action_dim: int, num_components: int, hidden_dim=256): Constructor to initialize the model.
    • forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Method to perform a forward pass through the network and return the mean, log variance, and logits.
    • get_distribution(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Method to get the distribution parameters from the state.

Code Implementation of Above Class

import torch
import torch.nn as nn

class GaussianMixtureModel(nn.Module):
    def __init__(self, state_dim, action_dim, num_components, hidden_dim=256):
        ...

    def forward(self, state):
        ...

    def get_distribution(self, state):
        ...

The Constructor

def __init__(self, state_dim, action_dim,num_components,hidden_dim=256):
        super(GaussianMixtureModel, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.num_components = num_components

        self.fc1 = nn.Linear(state_dim,hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
      
        #output action_dim * num_components   (4*5)
        self.mean = nn.Linear(hidden_dim, action_dim * num_components)
        self.log_var = nn.Linear(hidden_dim, action_dim * num_components)
        self.logits = nn.Linear(hidden_dim, action_dim * num_components)

The constructor initializes the model from the state dimension, action dimension, number of mixture components, and hidden dimension.

  • fc1: First fully connected layer.
  • fc2: Second fully connected layer.
  • mean: Fully connected layer that outputs the means of the Gaussian components.
  • log_var: Fully connected layer that outputs the log variances of the Gaussian components.
  • logits: Fully connected layer that outputs the logits for the mixture weights. Logits are the raw, unnormalized outputs of a neural network.

Forward Method

def forward(self, state):

        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        # mean is reshaped from (1, 20) to (1, 4, 5), and so are log_var and logits
        mean = self.mean(x).view(-1, self.action_dim, self.num_components)
        log_var = self.log_var(x).view(-1, self.action_dim, self.num_components)
        log_var = torch.clamp(log_var, -10, 10)  # Clipping log variance for stability
        logits = self.logits(x).view(-1, self.action_dim, self.num_components)
        return mean, log_var, logits

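This method defines the forward pass of the model. It returns the means, log variances, and logits of the Gaussian components, each reshaped to (batch_size, action_dim, num_components). The log variances are clamped to the range [-10, 10] for numerical stability.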

Distribution Method

def get_distribution(self, state):

        mean, log_var, logits = self.forward(state)
        return mean, torch.exp(log_var), torch.softmax(logits, dim=-1)

This method passes the state through forward and returns the means, variances, and mixture weights.

Variance is obtained by exponentiating the log variance to ensure positivity.

Mixture weights are obtained by applying softmax to logits to ensure they form a valid probability distribution.
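
The two transformations can be checked on dummy head outputs. The sketch below assumes 4 actions and 5 components (the sizes mentioned in the code comments) and uses random tensors in place of the network's outputs, so the variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
action_dim, num_components = 4, 5  # assumed sizes, matching the (4*5) comment

# Stand-ins for the raw head outputs of one state: shape (1, 20)
raw_log_var = torch.randn(1, action_dim * num_components)
raw_logits = torch.randn(1, action_dim * num_components)

# Reshape to (batch, action_dim, num_components) as in forward()
log_var = raw_log_var.view(-1, action_dim, num_components).clamp(-10, 10)
logits = raw_logits.view(-1, action_dim, num_components)

var = torch.exp(log_var)                 # strictly positive
weights = torch.softmax(logits, dim=-1)  # each action's weights sum to 1

print(var.shape, weights.shape)
print(weights.sum(dim=-1))
```

Because the softmax is taken over the last dimension, every action gets its own set of mixture weights.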

Replay Buffer

+----------------------------------+
|           ReplayBuffer           |
+----------------------------------+
| - buffer: list                   |
| - capacity: int                  |
+----------------------------------+
| + __init__(capacity: int)        |
| + push(experience: tuple)        |
| + sample(batch_size: int): list  |
+----------------------------------+

  • Attributes:
    • buffer: A list to hold the stored experiences.
    • capacity: An integer specifying the maximum number of experiences the buffer can store.
  • Methods:
    • __init__(capacity: int): Initializes the buffer list and sets the capacity.
    • push(experience: tuple): Adds an experience to the buffer. If the buffer exceeds its capacity, it removes the oldest experience.
    • sample(batch_size: int) -> list: Returns a random sample of experiences from the buffer, with the sample size specified by batch_size.

Code for Replay Buffer

import numpy as np

# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity

    def push(self, experience):
        # Drop the oldest experience once the buffer is full
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append(experience)

    def sample(self, batch_size):
        # Sample distinct indices uniformly at random
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        return [self.buffer[idx] for idx in indices]
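
Removing the oldest item with list.pop(0) shifts every remaining element, so its cost grows with the buffer size. As an alternative sketch (not the implementation used in this post), collections.deque with maxlen discards the oldest entry automatically. The class name DequeReplayBuffer is hypothetical.

```python
import random
from collections import deque


class DequeReplayBuffer:
    """Hypothetical variant of ReplayBuffer backed by a bounded deque."""

    def __init__(self, capacity):
        # A deque with maxlen drops its oldest item when a new one is appended
        self.buffer = deque(maxlen=capacity)

    def push(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        # Draw distinct positions, then look the experiences up
        indices = random.sample(range(len(self.buffer)), batch_size)
        return [self.buffer[idx] for idx in indices]


demo_buffer = DequeReplayBuffer(capacity=3)
for step in range(5):
    demo_buffer.push((step, "action", 0.0, step + 1, False))

print(len(demo_buffer.buffer))   # 3
print(demo_buffer.buffer[0][0])  # 2, because steps 0 and 1 were evicted
```

The interface matches push and sample above, so a buffer like this could be swapped in without touching the training loop.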

Loss Function

The compute_loss function calculates the loss for the Gaussian mixture model: the negative log likelihood of the target value under the predicted Gaussian mixture distribution.

def compute_loss(mean, std, weights, target):
    # mean, std, weights: (batch_size, num_components), already gathered for the taken action
    # target: (batch_size,)
    m = torch.distributions.Normal(mean, std)
    log_prob = m.log_prob(target.unsqueeze(-1).expand_as(mean))
    # Add the log mixture weights; the small constant guards against log(0)
    log_prob = log_prob + torch.log(weights + 1e-10)
    loss = -torch.logsumexp(log_prob, dim=-1).mean()

    return loss

Distribution Initialization:

m = torch.distributions.Normal(mean, std)

This line creates a Normal distribution object m for each Gaussian component in the mixture. torch.distributions.Normal takes the mean and the standard deviation (not the variance), so the caller must pass the square root of the variance.

Log Probability Calculation:

log_prob = m.log_prob(target.unsqueeze(-1).expand_as(mean))
  • This calculates the log probability of the target under each Gaussian component.
  • target.unsqueeze(-1).expand_as(mean) expands the target tensor from (batch_size,) to (batch_size, num_components) so that it matches the shape of mean.
  • m.log_prob returns the log probability of the target for each Gaussian component.

Adding Log Mixture Weights:

log_prob = log_prob + torch.log(weights + 1e-10)
  • weights holds the mixture weights returned by get_distribution, that is, the softmax of the logits rather than the raw logits.
  • torch.log(weights + 1e-10) adds the log of each mixture weight to the log probability of its component. The 1e-10 prevents taking the log of zero.
  • Both tensors have shape (batch_size, num_components), so the addition is element-wise and each sample keeps its own row.

Computing the Loss:

loss = -torch.logsumexp(log_prob, dim=-1).mean()
  • torch.logsumexp(log_prob, dim=-1) computes the log of the sum of the exponentiated entries across the mixture components, which is the log likelihood of the target under the mixture model.
  • The result is averaged over the batch and negated, because training minimizes the negative log likelihood.
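
As a sanity check on the formula, the mixture log likelihood can be computed for a single target by hand. The sketch below uses only the standard library and made-up numbers for a two-component mixture. It compares the log-sum-exp form against a direct evaluation of the mixture density.

```python
import math


def normal_log_pdf(x, mean, std):
    # Log density of a univariate Gaussian
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def mixture_log_likelihood(x, means, stds, weights):
    # log sum_k w_k * N(x | mean_k, std_k), computed with the log-sum-exp trick
    terms = [math.log(w) + normal_log_pdf(x, m, s)
             for m, s, w in zip(means, stds, weights)]
    peak = max(terms)
    return peak + math.log(sum(math.exp(t - peak) for t in terms))


# Illustrative values, not taken from a trained model
means, stds, weights = [0.0, 2.0], [1.0, 0.5], [0.3, 0.7]
target = 1.5

log_lik = mixture_log_likelihood(target, means, stds, weights)
direct = math.log(sum(w * math.exp(normal_log_pdf(target, m, s))
                      for m, s, w in zip(means, stds, weights)))

print(f"negative log likelihood: {-log_lik:.4f}")
print(f"difference from direct evaluation: {abs(log_lik - direct):.2e}")
```

Subtracting the largest term before exponentiating is what keeps the log-sum-exp form stable when the individual log probabilities are very negative.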

Load and Save Function

def save_checkpoint(state, filename='checkpoint.pth'):
    torch.save(state, filename)

def load_checkpoint(filename='checkpoint.pth', map_location=None):
    # map_location=None keeps torch.load's default device mapping
    return torch.load(filename, map_location=map_location)

Setting Device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Setting Hyperparameters

# Hyperparameters
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
hidden_dim = 128
num_components = 5
learning_rate = 0.0005
num_episodes = 2000
gamma = 0.99
batch_size = 64
buffer_capacity = 10000
target_update_freq = 5

# Exploration parameters
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.995

episode_rewards = []
epsilon = epsilon_start

Initialize

buffer = ReplayBuffer(buffer_capacity)
# Initialize model and target model
model = GaussianMixtureModel(state_dim, action_dim, num_components,hidden_dim)
target_model = GaussianMixtureModel(state_dim, action_dim, num_components,hidden_dim)
# Move the model to the chosen device
model.to(device)
target_model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Load model if exists

checkpoint_path = 'gdrl_june_19.pth'
start_episode = 0  # Default when no checkpoint is found
try:
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['main_net_state_dict'])
    target_model.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint, resuming at episode {start_episode}")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")
    # Start the target network from the same weights as the main network
    target_model.load_state_dict(model.state_dict())

Training Loop

Initialization and Loop Structure

for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0
  • Episode Loop: The outer loop runs for a specified number of episodes. Each episode starts with resetting the environment and initializing total_reward.

Time Step Loop

for t in range(1000):
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
    mean, var, logits = model.get_distribution(state_tensor)

    action_probs = mean.mean(dim=-1).softmax(dim=-1).cpu().detach().numpy()

    if np.isnan(action_probs).any():
        action_probs = np.nan_to_num(action_probs, nan=1.0/action_dim)
        action_probs /= action_probs.sum()

    if np.random.rand() < epsilon:
        action = np.random.choice(action_dim)
    else:
        action = np.argmax(action_probs[0])

    next_state, reward, done, _ = env.step(action)
    buffer.push((state, action, reward, next_state, done))
    state = next_state
    total_reward += reward
  • State Tensor Conversion: The current state is converted to a PyTorch tensor.
  • Get Distribution: The Gaussian mixture model outputs mean, var, and logits for the current state.
  • Action Probabilities: The component means are averaged for each action, and a softmax over the actions turns these averages into action probabilities.
  • Handle NaN Values: In case action_probs contains NaNs, they are replaced with a uniform distribution and re-normalized.
  • ε-Greedy Strategy: With probability epsilon, a random action is chosen (exploration). Otherwise, the action with the highest probability is chosen (exploitation).
  • Environment Step: The chosen action is taken in the environment, resulting in the next state, reward, and a done flag indicating whether the episode is finished.
  • Experience Storage: The experience (state, action, reward, next state, done) is stored in the replay buffer.
  • Update State and Reward: The current state is updated, and the reward is added to total_reward.
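
The action selection can be sketched without the environment. The snippet below uses plain Python lists in place of tensors, and the helper names expected_q and select_action are hypothetical. It differs from the code above in one deliberate way: each component mean is weighted by its mixture weight, which gives the expected value of the mixture, whereas the code above takes an unweighted average of the component means. Since softmax preserves ordering, the greedy choice is the same with or without it.

```python
import random


def expected_q(means, weights):
    # Expected value of a Gaussian mixture: sum_k w_k * mean_k
    return sum(w * m for m, w in zip(means, weights))


def select_action(q_values, epsilon, rng=random):
    # Explore with probability epsilon, otherwise act greedily
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Illustrative mixture parameters for two actions with three components each
means = [[1.0, 2.0, 3.0], [0.0, 5.0, 10.0]]
weights = [[0.2, 0.3, 0.5], [0.8, 0.1, 0.1]]

q_values = [expected_q(m, w) for m, w in zip(means, weights)]
greedy_action = select_action(q_values, epsilon=0.0)
print(q_values, greedy_action)
```

With these numbers the weighted estimate prefers action 0, while an unweighted average of the means (2.0 versus 5.0) would prefer action 1, so the two choices are not interchangeable.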

Training the Model

if len(buffer.buffer) >= batch_size:
    batch = buffer.sample(batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Build the batch tensors on the same device as the model
    states_tensor = torch.tensor(np.array(states), dtype=torch.float32, device=device)
    actions_tensor = torch.tensor(actions, dtype=torch.long, device=device)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=device)
    next_states_tensor = torch.tensor(np.array(next_states), dtype=torch.float32, device=device)
    dones_tensor = torch.tensor(dones, dtype=torch.float32, device=device)

    mean, var, logits = model.get_distribution(states_tensor)
    with torch.no_grad():
        next_mean, next_var, next_logits = target_model.get_distribution(next_states_tensor)

    target = rewards_tensor + gamma * torch.max(next_mean.mean(dim=-1), dim=1).values * (1 - dones_tensor)

    # Select the mixture parameters of the action taken in each transition
    action_index = actions_tensor.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, num_components)
    mean = mean.gather(1, action_index).squeeze(1)
    var = var.gather(1, action_index).squeeze(1)
    logits = logits.gather(1, action_index).squeeze(1)  # mixture weights (softmax output)
    std_dev = var.sqrt()
    loss = compute_loss(mean, std_dev, logits, target)

    optimizer.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()
  • Check Buffer Size: Training occurs only if the buffer has enough experiences (batch_size).
  • Sample Batch: A batch of experiences is sampled from the replay buffer.
  • Convert to Tensors: States, actions, rewards, next states, and dones are converted to PyTorch tensors.
  • Get Distribution for Current States: The model predicts mean, var, and logits for the current states.
  • Get Distribution for Next States: The target model predicts mean, var, and logits for the next states, using torch.no_grad() to prevent gradient computation.
  • Compute Target: The target Q-value is computed using the Bellman equation, considering the rewards and the discounted maximum Q-value of the next states, adjusted by the done flag.
  • Gather Q-Values: The predicted means, variances, and logits are gathered for the taken actions.
  • Compute Loss: The loss is computed using the compute_loss function.
  • Backpropagation: Gradients are computed and backpropagated.
  • Gradient Clipping: Gradients are clipped to a maximum norm of 1.0 to prevent exploding gradients.
  • Optimizer Step: The model parameters are updated.
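
The gather step is the least obvious part of the update. The following sketch, with small made-up sizes, shows that gathering along dimension 1 with the expanded action index picks out, for every transition, the component parameters of the action that was taken. The result matches plain advanced indexing.

```python
import torch

torch.manual_seed(0)
batch_size, action_dim, num_components = 3, 4, 5  # illustrative sizes

# Stand-in for the means predicted by the network: (batch, action, component)
mean = torch.randn(batch_size, action_dim, num_components)
actions = torch.tensor([2, 0, 3])  # action taken in each transition

# Index of shape (batch, 1, num_components) pointing at the taken action
index = actions.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, num_components)
selected = mean.gather(1, index).squeeze(1)  # (batch, num_components)

# Equivalent selection with advanced indexing
reference = mean[torch.arange(batch_size), actions]

print(selected.shape)
print(torch.equal(selected, reference))
```

After this step each row holds the mixture parameters of a single action, which is the shape the loss function receives.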

Episode End and Epsilon Decay

if done:
    break

Check if Episode is Done: If the done flag is True, the loop breaks, ending the episode.

epsilon = max(epsilon_end, epsilon_decay * epsilon)
episode_rewards.append(total_reward)
  • Decay Epsilon: epsilon is decayed after each episode, ensuring exploration decreases over time.
  • Store Total Reward: The total reward for the episode is appended to episode_rewards.
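
To get a feel for the schedule, the number of episodes needed for epsilon to reach its floor can be computed from the values in the hyperparameter section. The closed form below is standard algebra for geometric decay, and the loop simply repeats the update rule from the training code.

```python
import math

epsilon_start, epsilon_end, epsilon_decay = 1.0, 0.01, 0.995

# Closed form: smallest n with epsilon_start * epsilon_decay**n <= epsilon_end
closed_form = math.ceil(math.log(epsilon_end / epsilon_start) / math.log(epsilon_decay))

# Simulation using the same update as the training loop
epsilon, episodes = epsilon_start, 0
while epsilon > epsilon_end:
    epsilon = max(epsilon_end, epsilon_decay * epsilon)
    episodes += 1

print(closed_form, episodes)
```

With these settings the agent keeps exploring above the floor for a bit less than half of the 2000 episodes, and acts almost greedily for the rest.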

Target Network Update and Checkpoint Saving

if episode % target_update_freq == 0:
    target_model.load_state_dict(model.state_dict())
    print(f"Episode {episode}, Total Reward: {total_reward}")

Update Target Network: The target network is updated with the current model’s parameters every target_update_freq episodes.

if episode % 50 == 0:
    save_checkpoint({
        'episode': episode,
        'main_net_state_dict': model.state_dict(),
        'target_net_state_dict': target_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epsilon': epsilon
    }, checkpoint_path)
    print(f"Checkpoint saved at episode {episode}")

Save Checkpoint: A checkpoint of the model, target model, optimizer, and epsilon value is saved every 50 episodes.
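
A quick way to verify the checkpoint format is a round trip through a temporary file. The sketch below uses a tiny nn.Linear as a stand-in for the real networks and mirrors only a subset of the dictionary keys from the training code. The values are made up.

```python
import os
import tempfile

import torch
import torch.nn as nn

net = nn.Linear(8, 4)  # stand-in for the real model

checkpoint = {
    'episode': 50,
    'main_net_state_dict': net.state_dict(),
    'epsilon': 0.78,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'checkpoint.pth')
    torch.save(checkpoint, path)
    restored = torch.load(path, map_location='cpu')

# Load the saved weights into a freshly initialized network
restored_net = nn.Linear(8, 4)
restored_net.load_state_dict(restored['main_net_state_dict'])
print(restored['episode'], restored['epsilon'])
```

Passing map_location='cpu' lets a checkpoint written on a GPU machine be opened on a machine without one, which is the case the loading code guards against.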

Early Stopping

if sum(episode_rewards[-10:]) > 1500:
    print("Training done")
    save_checkpoint({
        'episode': episode,
        'main_net_state_dict': model.state_dict(),
        'target_net_state_dict': target_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epsilon': epsilon
    }, checkpoint_path)
    print(f"Checkpoint saved at episode {episode}")
    break
  • Early Stopping: If the sum of rewards for the last 10 episodes exceeds 1500, training is stopped early, and a checkpoint is saved.
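
The stopping rule can be isolated into a small helper. In the sketch below, should_stop is a hypothetical name. Unlike the inline check, it also waits until a full window of episodes is available. A sum above 1500 over 10 episodes corresponds to an average reward above 150 per episode.

```python
def should_stop(episode_rewards, window=10, threshold=1500):
    # Require a full window so the first few episodes cannot trigger the rule
    if len(episode_rewards) < window:
        return False
    return sum(episode_rewards[-window:]) > threshold


history = [120.0] * 10
print(should_stop(history))   # False: the sum is 1200

history.extend([200.0] * 10)
print(should_stop(history))   # True: the last 10 rewards sum to 2000
```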

In summary, the training loop interacts with the environment, stores experiences in the replay buffer, samples batches of experiences to train on, updates the model, decays the exploration rate, and periodically updates the target network and saves checkpoints. This structure helps the reinforcement learning agent learn stably and efficiently.
