Master Snake Game AI with PPO: Step-by-Step Guide (Part II)
In this second part, we will learn to build an AI agent based on the Proximal Policy Optimization (PPO) algorithm.
The Algorithm
PPO_Algorithm:
  Initialization:
    - Define ActorCritic Network:
      - Convolutional Layers:
        - Conv1: Input channels=1, Output channels=3, Kernel size=3
        - Initialize weights using Kaiming uniform
      - Fully Connected Layers:
        - FC1: Input=Calculated from convolution, Output=hidden_dim (128)
        - FC2: Input=hidden_dim, Output=hidden_dim
      - Output Layers:
        - Action: Linear(hidden_dim, action_dim)
        - Value: Linear(hidden_dim, 1)
      - Optimizer: Adam (lr=0.0002)
      - Activation Functions: custom Swish after the conv layer, tanh in the FC layers (a GELU is also defined)
    - Define PPO Agent:
      - Initialize ActorCritic and ActorCritic_old networks
      - Define MSE Loss function
      - Define memory buffer with max size
      - Set hyperparameters:
        - gamma: 0.99
        - K_epochs: 4
        - eps_clip: 0.2
        - hidden_dim: 128
  Training_Loop:
    - For each episode:
      - Reset environment and normalize state
      - Initialize total reward
      - Loop until done:
        - Select action from policy_old network
        - Step environment and collect:
          - next_state, reward, done
        - Normalize next_state
        - Store (state, action, log_prob, reward, done) in memory buffer
        - Update state and total reward
      - If memory size exceeds threshold:
        - Call update method
        - Clear memory buffer
      - Save total reward for the episode
      - Periodically save the model
  Update_Method:
    - Calculate Discounted Rewards:
      - For each timestep in memory (traversed in reverse):
        - If done, set discounted_reward = 0
        - Compute discounted_reward as reward + gamma * discounted_reward
      - Normalize discounted rewards
    - For each epoch in K_epochs:
      - Recalculate action log probabilities using current policy
      - Recalculate state values using value function
      - Compute policy ratios: exp(new_log_probs - old_log_probs)
      - Compute Surrogate Loss:
        - surr1 = policy_ratio * advantages
        - surr2 = clip(policy_ratio, 1 - eps_clip, 1 + eps_clip) * advantages
      - Minimize loss: -min(surr1, surr2) + 0.5 * MSELoss(state_values, discounted_rewards) - 0.02 * entropy
      - Backpropagation:
        - Check for NaN values in the loss or logits
        - Zero gradients and backpropagate the loss
        - Clip gradients to max_norm=0.5
        - Perform optimizer step
    - Update old policy to match the new policy
  Early_Stopping_and_Checkpoints:
    - If early stopping criterion is met:
      - Stop training
      - Optionally save the model
    - Save a checkpoint every 100 episodes
  Evaluate_Method:
    - Input: state, action
    - Calculate:
      - state_value from value network
      - action distribution from policy network
    - Return:
      - action log probability, state value, entropy of action distribution
Step I: Creation of the Actor-Critic Network
This ActorCritic class is part of the architecture used in the Proximal Policy Optimization (PPO) algorithm. In PPO, two types of neural networks are used: the policy network (Actor) and the value network (Critic). These networks help the agent decide which actions to take (Actor) and estimate the quality of states (Critic). Let’s break this down:
ActorCritic Class Overview
- Convolutional Layers: The network starts with a convolutional layer (conv1) to process grid-like inputs, such as images or matrices (common in environments like games or simulations). This layer helps the network capture spatial dependencies and patterns in the input.
- Fully Connected Layers: After the convolutional layer, the output is flattened and passed through fully connected (FC) layers (fc1, fc2), which process high-level abstract features (a worked example of the flattened input size follows after this list).
- Policy (pi) and Value (v) Networks: The network has two output heads:
  - Policy head (fc_pi): Outputs action probabilities (as logits).
  - Value head (fc_v): Outputs the value of the current state (used to estimate future rewards).
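For the default 10×10 grid used later in the code, the single 3×3 convolution (stride 1, no padding) shrinks each spatial dimension from 10 to 8, so the flattened feature vector feeding fc1 has 8 × 8 × 3 = 192 values. The short sketch below simply reproduces that calculation with the same formula the class uses:

# Output size of a conv layer: (size - (kernel_size - 1) - 1 + 2*padding) // stride + 1
def conv2d_size_out(size, kernel_size=3, stride=1, padding=0):
    return (size - (kernel_size - 1) - 1 + 2 * padding) // stride + 1

convw = conv2d_size_out(10)             # 8
convh = conv2d_size_out(10)             # 8
linear_input_size = convw * convh * 3   # 192 inputs to the first fully connected layer
print(convw, convh, linear_input_size)  # 8 8 192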
Two Networks in PPO: Policy Network (Actor) and Value Network (Critic)
- Policy Network (Actor):
  - Function: The actor decides which action to take based on the current state.
  - How it works: The pi() method implements the policy. It processes the input state through the convolutional and fully connected layers, producing a vector of action logits (unnormalized log-probabilities). Categorical(logits=x) turns these logits into a probability distribution over actions, and the agent samples from this distribution to select an action (see the short standalone sketch after this list).
  - Significance in PPO: The policy network defines how the agent interacts with the environment, and it is trained to maximize the total expected reward while ensuring smooth updates that avoid large deviations from the previous policy (using PPO’s clipping objective).
- Value Network (Critic):
  - Function: The critic estimates the expected value (future reward) of the current state.
  - How it works: The v() method processes the input state through the same convolutional and fully connected layers but outputs a scalar representing the value of the state. This value is used to compute the advantage, which helps reduce the variance of the policy-gradient updates and makes learning more stable.
  - Significance in PPO: The value network provides an estimate of how “good” the current state is. The advantage is the difference between the actual return and this estimate, and it guides the policy update by indicating how much better or worse an action performed than expected.
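To make the policy head concrete, here is a minimal, standalone sketch of how PyTorch’s Categorical turns raw logits into a distribution the agent can sample from. The logit values are invented for illustration and are not taken from the trained network:

import torch
from torch.distributions import Categorical

logits = torch.tensor([[1.2, 0.3, -0.5, 0.0]])  # hypothetical fc_pi output for one state
dist = Categorical(logits=logits)               # softmax over the logits
action = dist.sample()                          # e.g. tensor([0])
print(dist.probs)                               # probabilities over the 4 actions
print(dist.log_prob(action))                    # log-probability stored in the memory buffer
print(dist.entropy())                           # entropy used for the exploration bonus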
How They Work Together in PPO
- In PPO, the policy network (Actor) is trained to improve the agent’s decisions, while the value network (Critic) helps stabilize this training by providing estimates of expected rewards.
- The agent interacts with the environment using the old policy (stored in policy_old), collecting experience. Then, during the training process:
  - The critic evaluates the collected experiences, helping to estimate how much better or worse each action was compared to what was expected (the advantage).
  - The actor updates the policy using this feedback, improving the decision-making process to maximize rewards while keeping changes to the policy constrained (using PPO’s clipping mechanism).
By using both networks, PPO effectively balances exploration and exploitation: the policy network learns which actions to take, while the value network stabilizes learning by predicting future rewards more accurately.
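As a rough numerical illustration of the advantage (the numbers are invented): if the critic predicts a value of 2.0 for a state but the discounted return actually obtained is 3.5, the advantage is +1.5 and the sampled action is reinforced; a return of 1.0 instead gives an advantage of -1.0 and pushes the policy away from that action.

import torch

discounted_return = torch.tensor([3.5, 1.0])  # hypothetical returns for two transitions
value_estimate = torch.tensor([2.0, 2.0])     # critic predictions for the same states
advantage = discounted_return - value_estimate
print(advantage)  # tensor([ 1.5000, -1.0000]): first action encouraged, second discouraged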
The Code
# Imports used by the code below
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import init
from torch.distributions import Categorical

# Define the ActorCritic class with convolutional layers
class ActorCritic(nn.Module):
    def __init__(self, height=10, width=10, hidden_dim=128, action_dim=4):
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3)
        init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')

        # Calculate the size of the output from the convolutional layer
        def conv2d_size_out(size, kernel_size=3, stride=1, padding=0):
            return (size - (kernel_size - 1) - 1 + 2 * padding) // stride + 1

        convw = conv2d_size_out(width)
        convh = conv2d_size_out(height)
        self.linear_input_size = convw * convh * 3  # 3 is the number of output channels of conv1
        print(f"Linear input size: {self.linear_input_size}")

        # Fully connected layers
        self.fc1 = nn.Linear(self.linear_input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_pi = nn.Linear(hidden_dim, action_dim)  # policy head
        self.fc_v = nn.Linear(hidden_dim, 1)            # value head

        # Optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=0.0002)

        # GELU activation (defined here; the forward passes below use Swish and tanh)
        self.elu = nn.GELU()

    def swish(self, x):
        # Swish activation: x * sigmoid(x)
        return x * torch.sigmoid(x)

    def pi(self, x):
        # Policy head: returns a Categorical distribution over the action logits
        x = self.swish(self.conv1(x))
        x = x.reshape(x.size(0), self.linear_input_size)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc_pi(x)
        return Categorical(logits=x)

    def v(self, x):
        # Value head: returns the scalar state-value estimate
        x = self.swish(self.conv1(x))
        x = x.reshape(x.size(0), self.linear_input_size)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        v = self.fc_v(x)
        return v
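As a quick sanity check (not part of the original listing), you can instantiate the network and push a dummy 10×10 grid through both heads; the random input is only for illustration:

net = ActorCritic(height=10, width=10, hidden_dim=128, action_dim=4)
dummy_state = torch.randn(1, 1, 10, 10)  # (batch, channel, height, width)
dist = net.pi(dummy_state)               # Categorical distribution over 4 actions
value = net.v(dummy_state)               # state-value estimate, shape (1, 1)
print(dist.sample().item(), value.item())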
Step II: Experience Collection
In the train method:
- The agent interacts with the environment by choosing actions according to the old policy (policy_old) and stores each experience (state, action, log probability, reward, done) in a buffer (self.memory).
- States are normalized before being passed to the network.
- After the interaction, the agent stores this trajectory in the memory buffer; a minimal sketch of one such step is shown below.
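A minimal sketch of one such interaction step, assuming agent is a PPOAgent instance (defined later in this post) and env is a Gym-style Snake environment whose step() returns (next_state, reward, done, info):

state = agent.normalize_state(env.reset())  # normalize the raw grid
state_tensor = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0)  # add batch and channel dims
dist = agent.policy_old.pi(state_tensor)    # act with the *old* policy
action = dist.sample()
next_state, reward, done, _ = env.step(action.item())
agent.memory.append((state_tensor, action, dist.log_prob(action), reward, done))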
Step III: PPO Optimization
The core of the PPO algorithm is in the update method, which is executed periodically:
- Discounted Reward Calculation:
  - Compute discounted rewards by traversing the rewards from the end of the buffer to the beginning with a decay factor (gamma), resetting the running sum at terminal states.
  - Normalize the discounted rewards for stable learning.
- Policy Update:
  - The policy is updated for a fixed number of epochs (K_epochs), during which:
    - Action log probabilities (logprobs) are recalculated using the current policy (policy).
    - State values are predicted using the value function.
    - Policy Ratio: The ratio of new policy probabilities to old policy probabilities is computed as exp(logprobs - old_logprobs).
    - Surrogate loss: The core idea of PPO is to ensure that the new policy doesn’t deviate too much from the old policy. Two terms are computed:
      - surr1: the advantage weighted by the policy ratio.
      - surr2: the same, but with the ratio clipped to the range [1 - eps_clip, 1 + eps_clip] to prevent excessive updates.
      - The policy loss takes the minimum of these two terms, limiting how much the policy can change at each step.
    - Value loss: The difference between the predicted state values and the computed discounted rewards is also minimized.
    - Entropy bonus: A small entropy term is subtracted from the loss to encourage exploration.
- Gradient Clipping and Backpropagation:
  - Gradients are computed via backpropagation, followed by gradient clipping (max_norm=0.5) to stabilize training.
  - The optimizer updates the network parameters.
- Policy Synchronization:
  - After updating the policy network (self.policy), the old policy is synchronized with the new one via self.policy_old.load_state_dict(self.policy.state_dict()). A small numerical sketch of the return and surrogate-loss computations follows below.
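Here is a compact numerical sketch of the discounted-return pass and the clipped surrogate objective in isolation; the rewards, log-probabilities and value estimates are invented purely for illustration:

import torch

gamma, eps_clip = 0.99, 0.2
rewards   = [1.0, 0.0, -1.0]      # made-up rewards for a 3-step episode
terminals = [False, False, True]

# Backward pass over the trajectory, resetting the running return at terminal states
returns, G = [], 0.0
for r, done in zip(reversed(rewards), reversed(terminals)):
    if done:
        G = 0.0
    G = r + gamma * G
    returns.insert(0, G)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + 1e-7)  # same normalization as update()

# Clipped surrogate on invented log-probabilities and value estimates
new_logprobs = torch.tensor([-1.0, -1.4, -0.9])
old_logprobs = torch.tensor([-1.2, -1.3, -1.0])
values       = torch.tensor([ 0.2,  0.1, -0.3])
ratios = torch.exp(new_logprobs - old_logprobs)
advantages = returns - values
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
policy_loss = -torch.min(surr1, surr2).mean()
print(policy_loss)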
Step IV: Early Stopping & Checkpointing
The training process can be stopped early if a stopping criterion is met (via the early_stopping callback). The model is periodically saved, and checkpointing is supported, so you can save the policy and rewards at specific intervals.
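The early_stopping argument of train() is any callable that receives the list of episode rewards and returns True when training should stop. A possible criterion is sketched below; the window size and reward target are arbitrary illustrative choices, not values from the original code:

import numpy as np

def early_stopping(rewards, window=50, target=15.0):
    # Stop once the average reward over the last `window` episodes reaches `target`
    return len(rewards) >= window and np.mean(rewards[-window:]) >= target

# agent.train(env, num_episodes=5000, early_stopping=early_stopping, checkpoint_path="ppo_snake.pth")  # example call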
Summary of Key Components in PPO:
- Clip Objective: The loss function penalizes large changes in the policy between updates, ensuring small, stable policy changes.
- Multiple Epochs: The policy is updated for several epochs on the same batch of collected experience, improving sample efficiency without requiring new data from the environment at each step.
- Entropy Regularization: Encourages exploration by adding a penalty term for low-entropy (highly deterministic) action distributions (illustrated in the short sketch below).
- Value Function: The critic’s value function helps in reducing variance in the reward estimate, making the updates more stable.
This PPO agent uses convolutional layers in the policy to process spatial inputs like images or grids, which is useful for environments like Snake.
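To see why the entropy term helps, compare the entropy of a near-uniform action distribution with that of a highly deterministic one (the probabilities are invented). Because the loss subtracts 0.02 × entropy, the more exploratory distribution is rewarded:

import torch
from torch.distributions import Categorical

uniform_dist = Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))
peaked_dist  = Categorical(probs=torch.tensor([0.97, 0.01, 0.01, 0.01]))
print(uniform_dist.entropy())  # ~1.386 (high entropy: exploratory)
print(peaked_dist.entropy())   # ~0.168 (low entropy: nearly deterministic)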
The PPO class implementing the concepts described above
# Additional imports used by the agent (deque for the memory buffer, plotting libraries for plot())
from collections import deque
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Define the PPOAgent class
class PPOAgent:
    def __init__(self, height, width, action_dim=4, buffer_size=10000, gamma=0.99,
                 K_epochs=4, eps_clip=0.2, hidden_dim=128, device=None):
        self.policy = ActorCritic(height, width, hidden_dim, action_dim)
        self.policy_old = ActorCritic(height, width, hidden_dim, action_dim)
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.optimizer = self.policy.optimizer
        self.MseLoss = nn.MSELoss()
        self.memory = deque(maxlen=buffer_size)
        self.gamma = gamma
        self.K_epochs = K_epochs
        self.eps_clip = eps_clip
        self.device = device
        self.rewards = []

    def update(self):
        states, actions, logprobs, rewards, is_terminals = zip(*self.memory)

        # Compute discounted returns, resetting at episode boundaries
        discounted_rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            discounted_rewards.insert(0, discounted_reward)
        discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32)
        discounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-7)

        old_states = torch.cat(states).detach()
        old_actions = torch.cat(actions).detach()
        old_logprobs = torch.cat(logprobs).detach()

        for _ in range(self.K_epochs):
            logprobs, state_values, dist_entropy = self.evaluate(old_states, old_actions)
            ratios = torch.exp(logprobs - old_logprobs.detach())
            advantages = discounted_rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
            # loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, discounted_rewards) - 0.01 * dist_entropy
            loss = -torch.min(surr1, surr2) + 0.5 * self.MseLoss(state_values, discounted_rewards) - 0.02 * dist_entropy  # entropy coefficient raised as an example adjustment

            # Check for NaN before backpropagating
            if torch.isnan(loss).any():
                print("NaN detected in loss!")
                raise ValueError("NaN in loss detected.")

            self.optimizer.zero_grad()
            loss.mean().backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), max_norm=0.5)
            self.optimizer.step()

        # Synchronize the old policy with the freshly updated one
        self.policy_old.load_state_dict(self.policy.state_dict())

    def evaluate(self, state, action):
        state_value = self.policy.v(state)
        dist = self.policy.pi(state)
        if torch.isnan(dist.logits).any():
            print("NaN detected in logits!")
            dist.logits = torch.where(torch.isnan(dist.logits), torch.zeros_like(dist.logits), dist.logits)
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        return action_logprobs, torch.squeeze(state_value), dist_entropy

    def save(self, filename):
        checkpoint = {
            'model_state_dict': self.policy.state_dict(),
            'rewards': self.rewards
        }
        torch.save(checkpoint, filename)
        print(f"Model and rewards saved to {filename}")

    def load(self, filename):
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy.load_state_dict(checkpoint['model_state_dict'])
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.rewards = checkpoint.get('rewards', [])
        print(f"Model and rewards loaded from {filename}")

    def normalize_state(self, state):
        return (state - np.mean(state)) / (np.std(state) + 1e-8)

    def train(self, env, num_episodes, early_stopping=None, checkpoint_path=None):
        for episode in range(1, num_episodes + 1):
            total_reward = 0
            state = env.reset()
            state = self.normalize_state(state)
            done = False
            while not done:
                state_tensor = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
                dist = self.policy_old.pi(state_tensor)
                action = dist.sample()
                next_state, reward, done, _ = env.step(action.item())
                next_state = self.normalize_state(next_state)
                self.memory.append((state_tensor, action, dist.log_prob(action), reward, done))
                state = next_state
                total_reward += reward
                if done:
                    print(f"Episode: {episode} Reward: {total_reward}")
                    break

            if len(self.memory) > 10:
                print("Updating")
                self.update()
                self.memory.clear()

            self.rewards.append(total_reward)
            if early_stopping and early_stopping(self.rewards):
                print("Early stopping criterion met")
                if checkpoint_path:
                    self.save(checkpoint_path)
                break
            # Periodic checkpoint every 100 episodes (skipped when no path is given)
            if checkpoint_path and episode % 100 == 0:
                self.save(checkpoint_path)
        env.close()

    def test(self, env, num_episodes=10):
        for episode in range(num_episodes):
            state = env.reset()
            state = self.normalize_state(state)
            done = False
            total_reward = 0
            while not done:
                state_tensor = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
                dist = self.policy_old.pi(state_tensor)
                action = dist.sample()
                state, reward, done, _ = env.step(action.item())
                state = self.normalize_state(state)
                total_reward += reward
            print(f"Episode {episode + 1}: Total Reward: {total_reward}")
            self.rewards.append(total_reward)
        env.close()

    def plot(self, plot_path):
        data = self.rewards
        # Calculate the moving average
        window_size = 1000
        moving_avg = pd.Series(data).rolling(window=window_size).mean()
        # Plotting
        plt.figure(figsize=(10, 6))
        # Plot the moving average line
        sns.lineplot(data=moving_avg, color='red')
        # Shade the area around the moving average line to represent the range of values
        plt.fill_between(range(len(moving_avg)),
                         moving_avg - np.std(data),
                         moving_avg + np.std(data),
                         color='blue', alpha=0.2)
        plt.xlabel('Episodes')
        plt.ylabel('Rewards')
        plt.title('Moving Average of Rewards')
        plt.grid(True)
        plt.tight_layout()
        # Save the plot as a PNG file
        plt.savefig(plot_path)
        # Show the plot
        plt.show()
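Putting it all together, a training run might look like the sketch below. It assumes a SnakeEnv class (the Snake environment linked in the repository section) that exposes a Gym-style reset()/step() API on a 10×10 grid; the constructor arguments, episode count, and file names are illustrative assumptions rather than values from the original post:

env = SnakeEnv(height=10, width=10)                  # assumed constructor for the Snake environment
agent = PPOAgent(height=10, width=10, action_dim=4)

agent.train(env, num_episodes=5000, checkpoint_path="ppo_snake_checkpoint.pth")
agent.plot("rewards.png")                            # moving-average reward curve
agent.test(SnakeEnv(height=10, width=10), num_episodes=10)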
Git Repository
- Implementation of the PPO algorithm:
- Training the PPO agent to play Snake:
- Testing the trained agent:
- The trained agent:
- The Snake Env: