MoG-DQN Implementation In PyTorch To Solve Lunar Lander
Distributional deep reinforcement learning with a mixture of Gaussians (MoG-DQN) improves the performance and stability of deep reinforcement learning by taking a distributional view of the value function.
The value distribution is modeled as a mixture of Gaussians instead of a single point estimate or a simple categorical distribution.
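Concretely, for each state-action pair the network predicts K component means, variances and mixture weights (K is num_components in the code below), and the return distribution is approximated as

Z(s, a) \sim \sum_{k=1}^{K} w_k(s, a)\, \mathcal{N}\big(\mu_k(s, a),\, \sigma_k^2(s, a)\big), \qquad \sum_{k=1}^{K} w_k(s, a) = 1.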
Gaussian Mixture Model
+--------------------------------------------------+
| GaussianMixtureModel |
+--------------------------------------------------+
| - state_dim: int |
| - action_dim: int |
| - num_components: int |
| - fc1: nn.Linear |
| - fc2: nn.Linear |
| - mean: nn.Linear |
| - log_var: nn.Linear |
| - logits: nn.Linear |
+--------------------------------------------------+
| + __init__(self, state_dim: int, action_dim: int,|
| num_components: int, hidden_dim=256) |
| + forward(self, state: torch.Tensor) -> |
| Tuple[torch.Tensor, torch.Tensor, |
| torch.Tensor] |
| + get_distribution(self, state: torch.Tensor) -> |
| Tuple[torch.Tensor, torch.Tensor, |
| torch.Tensor] |
+--------------------------------------------------+
Attributes:
- state_dim: int – Dimension of the input state.
- action_dim: int – Dimension of the action space.
- num_components: int – Number of components in the Gaussian mixture.
- fc1: nn.Linear – First fully connected layer, mapping state_dim inputs to hidden_dim outputs.
- fc2: nn.Linear – Second fully connected layer, mapping hidden_dim inputs to hidden_dim outputs.
- mean: nn.Linear – Fully connected layer that outputs the means of the Gaussian components, mapping hidden_dim inputs to action_dim * num_components outputs.
- log_var: nn.Linear – Fully connected layer that outputs the log variances of the Gaussian components, mapping hidden_dim inputs to action_dim * num_components outputs.
- logits: nn.Linear – Fully connected layer that outputs the logits for the mixture weights, mapping hidden_dim inputs to action_dim * num_components outputs.
Methods:
- __init__(self, state_dim: int, action_dim: int, num_components: int, hidden_dim=256): Constructor that initializes the model.
- forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Performs a forward pass through the network and returns the means, log variances, and logits.
- get_distribution(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Returns the distribution parameters for a state: the means, variances, and mixture weights.
Code Implementation of the Above Class
class GaussianMixtureModel(nn.Module):
def __init__(self, state_dim, action_dim,num_components,hidden_dim=256):
...
def forward(self, state):
...
def get_distribution(self, state):
...
The Constructor
def __init__(self, state_dim, action_dim,num_components,hidden_dim=256):
super(GaussianMixtureModel, self).__init__()
self.state_dim = state_dim
self.action_dim = action_dim
self.num_components = num_components
self.fc1 = nn.Linear(state_dim,hidden_dim)
self.fc2 = nn.Linear(hidden_dim,hidden_dim)
#output action_dim * num_components (4*5)
self.mean = nn.Linear(hidden_dim, action_dim * num_components)
self.log_var = nn.Linear(hidden_dim, action_dim * num_components)
self.logits = nn.Linear(hidden_dim, action_dim * num_components)
Constructor method that initializes the model with the state dimension, action dimension, number of mixture components and hidden dimension.
- fc1: First fully connected layer.
- fc2: Second fully connected layer.
- mean: Also a fully connected layer; it outputs the means of the Gaussian components.
- log_var: Also a fully connected layer; it outputs the log variances of the Gaussian components.
- logits: Outputs the logits for the mixture weights. Logits are the raw, unnormalized outputs of a neural network.
Forward Method
def forward(self, state):
x = torch.relu(self.fc1(state))
x = torch.relu(self.fc2(x))
    # mean is reshaped from (1, 20) to (1, 4, 5); log_var and logits are reshaped the same way
mean = self.mean(x).view(-1, self.action_dim, self.num_components)
log_var = self.log_var(x).view(-1, self.action_dim, self.num_components)
log_var = torch.clamp(log_var, -10, 10) # Clipping log variance for stability
logits = self.logits(x).view(-1, self.action_dim, self.num_components)
return mean, log_var, logits
This method defines the forward pass of the model and returns the means, log variances and logits of the Gaussian components.
Distribution Method
def get_distribution(self, state):
mean, log_var, logits = self.forward(state)
return mean, torch.exp(log_var), torch.softmax(logits, dim=-1)
Method that computes and returns the means, variances and mixture weights by passing the state through the forward method.
Variance is obtained by exponentiating the log variance to ensure positivity.
Mixture weights are obtained by applying softmax to logits to ensure they form a valid probability distribution.
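As a quick smoke test (not part of the original code), the sketch below assumes the class above is defined and uses the Lunar Lander dimensions from later in this post (state_dim = 8, action_dim = 4, num_components = 5) to check the shapes returned by get_distribution:

import torch

# Hypothetical smoke test with Lunar Lander's dimensions and a single dummy state
model = GaussianMixtureModel(state_dim=8, action_dim=4, num_components=5)
state = torch.randn(1, 8)
mean, var, weights = model.get_distribution(state)

print(mean.shape, var.shape, weights.shape)   # each: torch.Size([1, 4, 5])
print(weights.sum(dim=-1))                    # ~1 for every (state, action) pair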
Replay Buffer
+----------------------------------+
| ReplayBuffer                     |
+----------------------------------+
| - buffer: list                   |
| - capacity: int                  |
+----------------------------------+
| + __init__(capacity: int)        |
| + push(experience: tuple)        |
| + sample(batch_size: int): list  |
+----------------------------------+
Attributes:
- buffer: A list that holds the stored experiences.
- capacity: An integer specifying the maximum number of experiences the buffer can store.
Methods:
- __init__(capacity: int): Initializes the buffer list and sets the capacity.
- push(experience: tuple): Adds an experience to the buffer. If the buffer exceeds its capacity, the oldest experience is removed.
- sample(batch_size: int) -> list: Returns a random sample of experiences from the buffer, with the sample size specified by batch_size.
Code for Replay Buffer
# Experience replay buffer
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = []
self.capacity = capacity
def push(self, experience):
if len(self.buffer) >= self.capacity:
self.buffer.pop(0)
self.buffer.append(experience)
def sample(self, batch_size):
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
return [self.buffer[idx] for idx in indices]
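As a small illustration (with made-up transitions, not part of the original code), pushing past capacity evicts the oldest experiences:

import numpy as np

buffer = ReplayBuffer(capacity=3)
for i in range(5):
    # dummy (state, action, reward, next_state, done) tuple
    buffer.push((np.zeros(8), 0, float(i), np.zeros(8), False))

print(len(buffer.buffer))   # 3 -- the two oldest experiences were dropped
batch = buffer.sample(2)    # 2 distinct experiences, drawn without replacement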
Loss Function
The compute_loss function calculates the loss for the Gaussian mixture model in the reinforcement learning setting. The loss is the negative log likelihood of the target value under the predicted Gaussian mixture distribution.
def compute_loss(mean, var, logits, target):
    # 'var' holds standard deviations and 'logits' the (already softmaxed) mixture weights
    m = torch.distributions.Normal(mean, var)
    # log probability of the target under each mixture component: (batch_size, num_components)
    log_prob = m.log_prob(target.unsqueeze(-1).expand_as(mean))
    # add the log mixture weights (1e-10 avoids log(0))
    log_prob = log_prob + torch.log(logits + 1e-10)
    # log-sum-exp over components gives the log likelihood under the mixture
    loss = -torch.logsumexp(log_prob, dim=-1).mean()
    return loss
Distribution Initialization:
m = torch.distributions.Normal(mean, var)
This line creates a normal distribution object m with the given mean and var (standard deviation) for each Gaussian component in the mixture. Note that torch.distributions.Normal takes a mean and a standard deviation (not a variance), so var must actually contain standard deviations; the training loop below takes care of this by passing in var.sqrt().
Log Probability Calculation:
log_prob = m.log_prob(target.unsqueeze(-1).expand_as(mean))
This calculates the log probability of the target under each Gaussian component. target.unsqueeze(-1).expand_as(mean) reshapes and expands the target tensor to match the shape of mean for broadcasting, and m.log_prob returns the log probability of target for each component.
Adding the Log Mixture Weights:
log_prob = log_prob + torch.log(logits + 1e-10)
Because mean, var and logits are gathered for the taken action before they reach this function (see the training loop below), log_prob has shape (batch_size, num_components), and the log of the mixture weights is simply added to the per-component log probabilities. The constant 1e-10 prevents taking the log of zero, for numerical stability. Note that by this point the "logits" have already been passed through a softmax in get_distribution, so they really are the mixture weights.
Computing the Loss:
loss = -torch.logsumexp(log_prob, dim=-1).mean()
torch.logsumexp(log_prob, dim=-1) computes the log-sum-exp of the weighted log probabilities across the mixture components, i.e. the log of the sum of the exponentiated values, which is the log likelihood of the target under the mixture model. Taking .mean() over the batch and negating the result gives the final loss, since we want to minimize the negative log likelihood.
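As a quick shape check (illustrative only, with random tensors standing in for a gathered batch of transitions), the function returns a single scalar:

import torch

batch_size, num_components = 64, 5
mean    = torch.randn(batch_size, num_components)                      # per-action component means
std     = torch.rand(batch_size, num_components) + 0.1                 # standard deviations (positive)
weights = torch.softmax(torch.randn(batch_size, num_components), -1)   # mixture weights
target  = torch.randn(batch_size)                                      # Bellman targets

loss = compute_loss(mean, std, weights, target)
print(loss)   # a single scalar: the mean negative log likelihood over the batch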
Load and Save Function
def save_checkpoint(state, filename='checkpoint.pth'):
torch.save(state, filename)
def load_checkpoint(filename='checkpoint.pth', map_location=None):
if map_location:
return torch.load(filename, map_location=map_location)
return torch.load(filename)
Setting Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
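The hyperparameter block below reads state_dim and action_dim from env, which is assumed to have been created earlier. With the classic Gym API (which matches the env.reset() and env.step() signatures used in the training loop), that setup would look roughly like this, together with the imports the rest of the code relies on:

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# LunarLander-v2: 8-dimensional observation, 4 discrete actions (old-style Gym API assumed)
env = gym.make('LunarLander-v2')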
Setting Hyperparameters
# Hyperparameters
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
hidden_dim = 128
num_components = 5
learning_rate = 0.0005
num_episodes = 2000
gamma = 0.99
batch_size = 64
buffer_capacity = 10000
target_update_freq = 5
# Exploration parameters
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.995
episode_rewards = []
epsilon = epsilon_start
Initialize
buffer = ReplayBuffer(buffer_capacity)
# Initialize model and target model
model = GaussianMixtureModel(state_dim, action_dim, num_components,hidden_dim)
target_model = GaussianMixtureModel(state_dim, action_dim, num_components,hidden_dim)
# Move the model to the chosen device
model.to(device)
target_model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
Load model if exists
checkpoint_path = 'gdrl_june_19.pth'
try:
map_location = torch.device('cpu') if not torch.cuda.is_available() else None
checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
model.load_state_dict(checkpoint['main_net_state_dict'])
target_model.load_state_dict(checkpoint['target_net_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epsilon = checkpoint['epsilon']
start_episode = checkpoint['episode'] + 1
print(f"Loaded checkpoint from episode {start_episode}")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")
    start_episode = 0
    target_model.load_state_dict(model.state_dict())
Training Loop
Initialization and Loop Structure
for episode in range(start_episode, num_episodes):
    state = env.reset()
    total_reward = 0
- Episode Loop: The outer loop runs over episodes (from start_episode up to num_episodes, so training resumes correctly from a loaded checkpoint). Each episode starts by resetting the environment and initializing total_reward.
Time Step Loop
for t in range(1000):
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
mean, var, logits = model.get_distribution(state_tensor)
action_probs = mean.mean(dim=-1).softmax(dim=-1).cpu().detach().numpy()
if np.isnan(action_probs).any():
action_probs = np.nan_to_num(action_probs, nan=1.0/action_dim)
action_probs /= action_probs.sum()
if np.random.rand() < epsilon:
action = np.random.choice(action_dim)
else:
action = np.argmax(action_probs[0])
next_state, reward, done, _ = env.step(action)
buffer.push((state, action, reward, next_state, done))
state = next_state
total_reward += reward
- State Tensor Conversion: The current state is converted to a PyTorch tensor and moved to the device.
- Get Distribution: The Gaussian mixture model outputs mean, var, and logits for the current state.
- Action Probabilities: The component means are averaged per action and softmaxed to obtain action probabilities (a weighted alternative is sketched after this list).
- Handle NaN Values: If action_probs contains NaNs, they are replaced with a uniform distribution and re-normalized.
- ε-Greedy Strategy: With probability epsilon, a random action is chosen (exploration). Otherwise, the action with the highest probability is chosen (exploitation).
- Environment Step: The chosen action is taken in the environment, yielding the next state, the reward, and a done flag indicating whether the episode is finished.
- Experience Storage: The experience (state, action, reward, next state, done) is stored in the replay buffer.
- Update State and Reward: The current state is updated, and the reward is added to total_reward.
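Note that mean.mean(dim=-1) is an unweighted average over the mixture components. A variant (not what the original code does) that weights each component mean by its mixture weight to form the expected value per action would look like this:

# Alternative greedy rule: expected Q-value per action under the mixture, sum_k w_k * mu_k
mean, var, weights = model.get_distribution(state_tensor)
expected_q = (weights * mean).sum(dim=-1)            # shape: (1, action_dim)
greedy_action = int(expected_q.argmax(dim=-1).item())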
Training the Model
if len(buffer.buffer) >= batch_size:
    batch = buffer.sample(batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)
    # move the sampled batch onto the same device as the model
    states_tensor = torch.tensor(np.array(states), dtype=torch.float32).to(device)
    actions_tensor = torch.tensor(actions, dtype=torch.long).to(device)
    rewards_tensor = torch.tensor(rewards, dtype=torch.float32).to(device)
    next_states_tensor = torch.tensor(np.array(next_states), dtype=torch.float32).to(device)
    dones_tensor = torch.tensor(dones, dtype=torch.float32).to(device)
    mean, var, logits = model.get_distribution(states_tensor)
    with torch.no_grad():
        # Bellman target from the target network (no gradients flow through it)
        next_mean, next_var, next_logits = target_model.get_distribution(next_states_tensor)
        target = rewards_tensor + gamma * torch.max(next_mean.mean(dim=-1), dim=1).values * (1 - dones_tensor)
    # keep only the distribution parameters of the actions actually taken
    mean = mean.gather(1, actions_tensor.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, num_components)).squeeze(1)
    var = var.gather(1, actions_tensor.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, num_components)).squeeze(1)
    logits = logits.gather(1, actions_tensor.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, num_components)).squeeze(1)
    std_dev = var.sqrt()
    loss = compute_loss(mean, std_dev, logits, target)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
- Check Buffer Size: Training occurs only once the buffer holds at least batch_size experiences.
- Sample Batch: A batch of experiences is sampled from the replay buffer.
- Convert to Tensors: States, actions, rewards, next states, and dones are converted to PyTorch tensors and moved to the training device.
- Get Distribution for Current States: The model predicts mean, var, and logits for the current states.
- Get Distribution for Next States: The target model predicts mean, var, and logits for the next states, inside torch.no_grad() so that no gradients are computed for the target network.
- Compute Target: The target Q-value is computed with the Bellman equation, using the rewards and the discounted maximum Q-value of the next states, masked by the done flag.
- Gather Q-Values: The predicted means, variances, and mixture weights are gathered for the actions actually taken.
- Compute Loss: The loss is computed with the compute_loss function.
- Backpropagation: Gradients are computed and backpropagated.
- Gradient Clipping: Gradients are clipped to a maximum norm of 1.0 to prevent exploding gradients.
- Optimizer Step: The model parameters are updated.
Episode End and Epsilon Decay
if done:
break
Check if Episode is Done: If the done flag is True, the loop breaks, ending the episode.
epsilon = max(epsilon_end, epsilon_decay * epsilon)
episode_rewards.append(total_reward)
- Decay Epsilon: epsilon is multiplied by epsilon_decay after each episode (but never drops below epsilon_end), so exploration decreases over time; with epsilon_decay = 0.995 it takes roughly 920 episodes for epsilon to fall from 1.0 to the 0.01 floor.
- Store Total Reward: The total reward for the episode is appended to episode_rewards.
Target Network Update and Checkpoint Saving
if episode % target_update_freq == 0:
target_model.load_state_dict(model.state_dict())
print(f"Episode {episode}, Total Reward: {total_reward}")
Update Target Network: The target network is updated with the current model's parameters every target_update_freq episodes.
if episode % 50 == 0:
save_checkpoint({
'episode': episode,
'main_net_state_dict': model.state_dict(),
'target_net_state_dict': target_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epsilon': epsilon
}, checkpoint_path)
print(f"Checkpoint saved at episode {episode}")
Save Checkpoint: A checkpoint of the model, target model, optimizer, and epsilon value is saved every 50 episodes.
Early Stopping
if sum(episode_rewards[-10:]) > 1500:
print("Training done")
save_checkpoint({
'episode': episode,
'main_net_state_dict': model.state_dict(),
'target_net_state_dict': target_model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epsilon': epsilon
}, checkpoint_path)
print(f"Checkpoint saved at episode {episode}")
break
- Early Stopping: If the sum of rewards for the last 10 episodes exceeds 1500, training is stopped early, and a checkpoint is saved.
In summary, the training loop involves interacting with the environment, storing experiences in the replay buffer, sampling and training on batches of experiences, updating the model, decaying the exploration rate, and periodically updating the target network and saving checkpoints. This structured approach helps the reinforcement learning agent learn in a stable and efficient way.