data:image/s3,"s3://crabby-images/6f8bd/6f8bde8e439f44885df001d63583645babb3e301" alt="Prioritized Experience Replay Using PyTorch Prioritized Experience Replay Using PyTorch"
Prioritized Experience Replay Using PyTorch
Prioritized Experience Replay (PER) replays important transitions more frequently and therefore learns more efficiently. It uses the magnitude of a transition's temporal-difference (TD) error as a proxy for how much the agent can expect to learn from replaying it.
“Greedy TD-Error Prioritization”:
This algorithm stores the last encountered TD error along with each transition in the replay memory. The transition with the largest absolute TD error is replayed, and a Q-learning update is applied to it, changing the weights in proportion to the TD error. New transitions arrive without a known TD error, so they are given maximal priority to guarantee that every transition is seen at least once.
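To make the greedy scheme concrete, here is a minimal Python sketch. The `GreedyTDReplay` class and its method names are hypothetical, not taken from the PER paper or any library. It keeps the last |TD error| for each transition, always replays the arg-max, and inserts new transitions at the current maximum:

```python
import numpy as np


class GreedyTDReplay:
    """Hypothetical sketch of greedy TD-error prioritization."""

    def __init__(self):
        self.transitions = []
        self.td_errors = []  # last known |TD error| for each stored transition

    def add(self, transition):
        # No TD error is known yet, so use the current maximum (1.0 when empty):
        # the new transition ties with the highest-priority one.
        max_error = max(self.td_errors, default=1.0)
        self.transitions.append(transition)
        self.td_errors.append(max_error)

    def replay(self):
        # Greedy choice: the transition with the largest stored |TD error|
        # (np.argmax breaks ties in favour of the oldest entry).
        idx = int(np.argmax(self.td_errors))
        return idx, self.transitions[idx]

    def update(self, idx, td_error):
        # Only the replayed transition has its TD error refreshed.
        self.td_errors[idx] = abs(td_error)


memory = GreedyTDReplay()
memory.add("t0")
memory.add("t1")
first_idx, _ = memory.replay()   # both start at 1.0, the tie goes to t0
memory.update(first_idx, 0.1)    # t0's error is now small
second_idx, _ = memory.replay()  # t1 still has priority 1.0
```

After a single update, `t0` drops behind every transition that has not been replayed yet; in a large memory it may then wait a long time before it is replayed again.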
Stochastic Prioritization
There are a few problems with “Greedy TD-error Prioritization”.
First, to avoid expensive sweeps over the entire replay memory, TD errors are only updated for the transitions that are replayed. As a consequence, transitions with a low TD error on their first visit may not be replayed for a long time. Second, greedy prioritization is sensitive to noise spikes (for example, when rewards are stochastic), which can be exacerbated by bootstrapping, where approximation errors appear as another source of noise. Finally, greedy prioritization focuses on a small subset of the experience: errors shrink slowly, especially when using function approximation, so the transitions with initially high errors get replayed frequently. This lack of diversity makes the system prone to overfitting.
To overcome these issues, PER introduces a stochastic sampling method that interpolates between pure greedy prioritization and uniform random sampling. It ensures that the probability of being sampled is monotonic in a transition's priority, while guaranteeing a non-zero probability even for the lowest-priority transition.
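The probability of sampling transition $i$ is defined as:

$$P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}$$

where $p_i > 0$ is the priority of transition $i$. The exponent $\alpha$ determines how much prioritization is used, with $\alpha = 0$ corresponding to the uniform case. In the proportional variant implemented below, $p_i = |\delta_i| + \epsilon$, where $\epsilon$ is a small positive constant that keeps transitions with zero TD error sampleable.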
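As a quick numerical check of the formula, the NumPy sketch below computes $P(i)$ for three priorities at $\alpha = 0$, $0.6$ and $1$. The helper name `sampling_probabilities` and the priority values are illustrative. Lowering $\alpha$ flattens the distribution towards uniform while preserving the order of the priorities:

```python
import numpy as np


def sampling_probabilities(priorities, alpha):
    """P(i) = p_i**alpha / sum_k p_k**alpha for an array of priorities p_i > 0."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()


priorities = np.array([0.1, 1.0, 4.0])                   # illustrative |TD error| values
uniform = sampling_probabilities(priorities, alpha=0.0)  # every p_i**0 == 1
partial = sampling_probabilities(priorities, alpha=0.6)  # the default alpha used later
full = sampling_probabilities(priorities, alpha=1.0)     # proportional to p_i
```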
Annealing The Bias
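Stochastic updates estimate an expected value without bias only if they are drawn from the same distribution as that expectation. Prioritized replay changes this distribution in an uncontrolled fashion, which introduces bias and therefore changes the solution that the estimates converge to. This is corrected with importance-sampling (IS) weights

$$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^{\beta}$$

where $N$ is the number of transitions in the replay memory. These weights fully compensate for the non-uniform probabilities $P(i)$ when $\beta = 1$, and they are folded into the Q-learning update by using $w_i \delta_i$ instead of $\delta_i$. For stability, the weights are normalized by $1/\max_i w_i$ so that they only scale the update downwards. Because unbiased updates matter most near convergence, $\beta$ is annealed from an initial value $\beta_0$ towards 1 over the course of training.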
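The IS weights and the annealing of $\beta$ can be sketched in NumPy as follows. The linear schedule and its parameters `beta0` and `total_steps` are assumptions for illustration: the PER paper anneals $\beta$ linearly to 1, but the schedule length is a tuning choice. Note that the buffer's `sample` method below keeps `beta` at its default of 0.4 unless another value is passed. As in that method, the weights are normalized by the largest weight in the sampled batch:

```python
import numpy as np


def importance_weights(probs, indices, beta):
    """w_i = (N * P(i))**(-beta), scaled so the largest weight in the batch is 1."""
    n = len(probs)
    weights = (n * probs[indices]) ** (-beta)
    return weights / weights.max()


def linear_beta(step, beta0=0.4, total_steps=100_000):
    """Hypothetical schedule: anneal beta linearly from beta0 to 1.0, then hold it."""
    fraction = min(step / total_steps, 1.0)
    return beta0 + fraction * (1.0 - beta0)


probs = np.array([0.1, 0.2, 0.7])  # illustrative P(i) for a 3-transition buffer
indices = np.array([0, 2, 2])      # a sampled batch (with replacement)
w_early = importance_weights(probs, indices, linear_beta(0))       # beta = 0.4
w_late = importance_weights(probs, indices, linear_beta(100_000))  # beta = 1.0
```

The rarely sampled transition 0 receives the largest weight, and the frequently sampled transition 2 is scaled down more strongly as $\beta$ approaches 1.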
Implementation using PyTorch
We start by building the prioritized replay buffer class.
Prioritized Replay Buffer
The `PrioritizedReplayBuffer` class implements a prioritized experience replay buffer, which is used in reinforcement learning to store transitions (state, action, reward, next_state, done) and sample them with a priority based on their importance. Here's a detailed explanation of each part of the class:
```python
import numpy as np

class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0
```
- `capacity`: Maximum number of transitions the buffer can hold.
- `alpha`: Controls the level of prioritization. With `alpha` = 0 it behaves like uniform random sampling; with `alpha` = 1 sampling is fully proportional to the TD error.
- `buffer`: List that stores the transitions.
- `priorities`: Array that stores the priority value of each transition.
- `position`: Index where the next transition will be inserted.
Adding Transitions
```python
    def add(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority (1.0 for an empty buffer)
        max_prio = self.priorities.max() if self.buffer else 1.0
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            # Buffer full: overwrite the oldest transition
            self.buffer[self.position] = (state, action, reward, next_state, done)
        self.priorities[self.position] = max_prio
        self.position = (self.position + 1) % self.capacity
```
- `max_prio`: The maximum priority value in the buffer, or 1.0 if the buffer is empty.
- Adding the transition: If the buffer is not full, the transition is appended to the buffer. If the buffer is full, the transition at `self.position` is overwritten.
- Setting the priority: The new transition is assigned the maximum priority.
- Updating the position: The position pointer advances circularly using the modulo operation, so once the buffer is full the oldest transitions are overwritten first.
Sampling Transitions
```python
    def sample(self, batch_size, beta=0.4):
        # Only the filled part of the buffer has meaningful priorities
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.position]
        # P(i) = p_i^alpha / sum_k p_k^alpha, computed in float64 so it sums to 1 precisely
        probs = prios.astype(np.float64) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        # Importance-sampling weights (N * P(i))^(-beta), normalized by the batch maximum
        total = len(self.buffer)
        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)
        # Regroup the list of transitions into one array per field
        batch = list(zip(*samples))
        states = np.array(batch[0])
        actions = np.array(batch[1])
        rewards = np.array(batch[2])
        next_states = np.array(batch[3])
        dones = np.array(batch[4])
        return states, actions, rewards, next_states, dones, indices, weights
```
Handling priorities: If the buffer is full, `prios` contains all priorities. Otherwise, it contains only the priorities up to the current position.
Calculating probabilities: Priorities are raised to the power of `alpha` to calculate sampling probabilities and then normalized.
Sampling indices: Transitions are sampled according to the calculated probabilities.
Computing weights: Importance-sampling weights are calculated and normalized. This compensates for the non-uniform probability of sampling.
Batch formation: The sampled transitions are grouped into batches.
Return: The method returns the batch of transitions, their indices, and importance-sampling weights.
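The `sample` method above recomputes and normalizes `prios ** alpha` over the whole buffer on every call, which costs O(N) per batch. The PER paper's proportional variant instead keeps priorities in a sum-tree, where updates and prefix-sum lookups take O(log N), and draws one value uniformly from each of `batch_size` equal segments of the total priority. The sketch below is a minimal, illustrative version; the `SumTree` class is hypothetical and not part of this article's code or of any library. Its leaves would hold $p_i^\alpha$:

```python
import numpy as np


class SumTree:
    """Hypothetical array-backed binary sum-tree over `capacity` leaf priorities."""

    def __init__(self, capacity):
        self.capacity = capacity
        # Nodes 0..capacity-2 are internal (each stores the sum of its two
        # children); the last `capacity` entries are the leaves.
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)

    def total(self):
        return self.tree[0]  # the root holds the sum of all priorities

    def update(self, data_idx, priority):
        node = data_idx + self.capacity - 1
        change = priority - self.tree[node]
        self.tree[node] = priority
        while node != 0:  # propagate the change up to the root
            node = (node - 1) // 2
            self.tree[node] += change

    def find(self, value):
        """Return the data index whose cumulative-priority interval contains value."""
        node = 0
        while node < self.capacity - 1:  # descend until a leaf is reached
            left = 2 * node + 1
            if value < self.tree[left]:
                node = left
            else:
                value -= self.tree[left]
                node = left + 1
        return node - (self.capacity - 1)

    def sample(self, batch_size, rng):
        # Stratified sampling: one uniform draw from each equal-width segment.
        segment = self.total() / batch_size
        values = (np.arange(batch_size) + rng.random(batch_size)) * segment
        return np.array([self.find(v) for v in values])


tree = SumTree(5)
for data_idx, priority in enumerate([1.0, 0.0, 3.0, 0.0, 6.0]):
    tree.update(data_idx, priority)
batch = tree.sample(4, np.random.default_rng(0))
```

Plugging this into `PrioritizedReplayBuffer` would mean writing `prio ** alpha` into the tree in `add` and `update_priorities`, and computing $P(i)$ for the IS weights as the leaf value divided by `tree.total()`.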
Updating Priorities
```python
    def update_priorities(self, batch_indices, batch_priorities):
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio
```
Updating priorities: The method overwrites the priorities of the sampled transitions with the new values it is given, which `train_step` computes from the TD errors as $|\delta_i| + 10^{-5}$.
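The NumPy sketch below (with made-up Q-values and targets) shows why a small constant is added to $|\delta_i|$: a transition whose TD error is exactly zero would otherwise get priority 0, so it would never be sampled, and its priority never updated, again.

```python
import numpy as np

alpha = 0.6
eps = 1e-5                                # the constant train_step adds
q_values = np.array([1.0, 0.5, 2.0])      # illustrative Q(s_i, a_i)
targets = np.array([1.5, 0.5, -1.0])      # illustrative Bellman targets

td_errors = targets - q_values            # delta_i = [0.5, 0.0, -3.0]
new_priorities = np.abs(td_errors) + eps  # p_i = |delta_i| + eps

# Write them back the same way update_priorities does.
priorities = np.ones(5, dtype=np.float32)
batch_indices = [3, 0, 4]
for idx, prio in zip(batch_indices, new_priorities):
    priorities[idx] = prio

probs = priorities ** alpha / np.sum(priorities ** alpha)
```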
The Neural Network
The `DQN` class defines a Deep Q-Network (DQN) in PyTorch: a neural network used in reinforcement learning to approximate the Q-value function. The Q-value function predicts the expected discounted return for taking a given action in a given state. Here's a detailed breakdown of the class:
Initialization
```python
import torch
import torch.nn as nn

class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)
```
- `input_dim`: The number of input features, which corresponds to the dimension of the state space.
- `output_dim`: The number of output features, which corresponds to the number of possible actions.

The network consists of three fully connected (linear) layers:
- `self.fc1`: The first layer maps the input features to 128 hidden units.
- `self.fc2`: The second layer maps the 128 hidden units from the first layer to another 128 hidden units.
- `self.fc3`: The final layer maps the 128 hidden units to the output dimension, giving one Q-value per action.
Forward Pass
```python
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
- Input: `x` is the input to the network, i.e. the state representation.
- First Layer: The input `x` is passed through the first fully connected layer `fc1`, followed by a ReLU activation function.
- Second Layer: The output of the first layer is passed through the second fully connected layer `fc2`, followed by another ReLU activation function.
- Third Layer: The output of the second layer is passed through the final fully connected layer `fc3`.
The network does not use an activation function after the final layer because the output represents the Q-values, which can be any real numbers.
The DQN Agent
The `DQNAgent` class defines a reinforcement learning agent that uses a Deep Q-Network (DQN) and a prioritized replay buffer to interact with an environment, select actions, and train the policy network. Here's a detailed breakdown of the class:
Initialization
```python
import random
import torch.optim as optim

class DQNAgent:
    def __init__(self, env, buffer, batch_size=64, gamma=0.99, lr=1e-3,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=500):
        self.env = env
        self.buffer = buffer
        self.batch_size = batch_size
        self.gamma = gamma
        self.lr = lr
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.epsilon = epsilon_start
        self.steps_done = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN(env.observation_space.shape[0], env.action_space.n).to(self.device)
        self.target_net = DQN(env.observation_space.shape[0], env.action_space.n).to(self.device)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.loss_fn = nn.MSELoss()  # kept for reference; train_step computes an IS-weighted MSE directly
        self.update_target()
```
- `env`: The environment the agent interacts with.
- `buffer`: The prioritized replay buffer for storing and sampling transitions.
- `batch_size`: Number of transitions sampled per training step.
- `gamma`: Discount factor for future rewards.
- `lr`: Learning rate for the optimizer.
- `epsilon_start`, `epsilon_end`, `epsilon_decay`: Parameters of the epsilon-greedy exploration schedule (initial value, final value, and decay time constant in steps).
- `device`: The device (CPU or GPU) used for computation.
- `policy_net`: The main DQN, used to select actions and trained at each step.
- `target_net`: A separate DQN used to compute target Q-values for stability.
- `optimizer`: Adam optimizer for training the policy network.
- `loss_fn`: Mean Squared Error loss function. It is not used by `train_step`, which computes the importance-sampling-weighted squared error directly.
- `update_target`: Initializes the target network with the same weights as the policy network.
Updating the Target Network
```python
    def update_target(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())
```
Copies the weights from the policy network to the target network.
Action Selection
```python
    def select_action(self, state):
        self.steps_done += 1
        # Exponential decay from epsilon_start towards epsilon_end
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(-1. * self.steps_done / self.epsilon_decay)
        if random.random() > self.epsilon:
            # Exploit: greedy action from the policy network
            with torch.no_grad():
                state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
                return self.policy_net(state).argmax().item()
        else:
            # Explore: uniformly random action
            return self.env.action_space.sample()
```
Uses an epsilon-greedy policy to select actions. The exploration rate (`epsilon`) decays exponentially from `epsilon_start` towards `epsilon_end`, with `epsilon_decay` as the time constant in steps. If a random number is greater than `epsilon`, the agent selects the action with the highest Q-value from the policy network; otherwise, it selects a random action.
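The schedule in `select_action` is $\epsilon_t = \epsilon_{end} + (\epsilon_{start} - \epsilon_{end})\,e^{-t/\tau}$ with $\tau$ = `epsilon_decay`. Evaluating it with the default arguments shows how quickly exploration fades, since `steps_done` counts environment steps across all episodes. The helper `epsilon_at` is a hypothetical standalone re-implementation for illustration:

```python
import math


def epsilon_at(step, eps_start=1.0, eps_end=0.01, eps_decay=500):
    """Same exponential schedule as DQNAgent.select_action, as a pure function."""
    return eps_end + (eps_start - eps_end) * math.exp(-step / eps_decay)


for step in (1, 500, 1000, 2500, 5000):
    print(f"step {step:>5}: epsilon = {epsilon_at(step):.3f}")
```

With the defaults, epsilon is about 0.374 after 500 steps, 0.144 after 1,000 and 0.017 after 2,500, so the agent acts almost greedily after a few thousand environment steps; a larger `epsilon_decay` keeps it exploring for longer.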
Training Step
```python
    def train_step(self):
        # Wait until the buffer holds at least one full batch
        if len(self.buffer.buffer) < self.batch_size:
            return
        states, actions, rewards, next_states, dones, indices, weights = self.buffer.sample(self.batch_size)
        states = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)
        weights = torch.as_tensor(weights, dtype=torch.float32, device=self.device)
        # Q(s, a) for the actions actually taken
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # Bellman target: r + gamma * max_a' Q_target(s', a'), no bootstrap on terminal states
        next_q_values = self.target_net(next_states).max(1)[0]
        expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)
        # Importance-sampling-weighted squared TD error
        loss = (weights * (q_values - expected_q_values.detach()).pow(2)).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # New priorities |delta_i| + eps; eps keeps every transition sampleable
        priorities = (q_values - expected_q_values).abs().cpu().detach().numpy() + 1e-5
        self.buffer.update_priorities(indices, priorities)
```
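Returns immediately if the buffer holds fewer than `batch_size` transitions; otherwise samples a batch of transitions from the prioritized replay buffer.
Converts the sampled transitions to PyTorch tensors and moves them to the appropriate device.
Computes the Q-values for the current states and the actions taken using the policy network.
Computes the maximum Q-value over actions for the next states using the target network.
Calculates the expected Q-values with the Bellman equation, $r + \gamma \max_{a'} Q_{\text{target}}(s', a')$, dropping the bootstrap term for terminal transitions.
Computes the loss as the mean squared TD error, with each term multiplied by its importance-sampling weight.
Performs backpropagation to update the policy network.
Updates the priorities in the replay buffer to $|\delta_i| + 10^{-5}$, where $\delta_i$ is the new TD error.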
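The arithmetic of the target and the weighted loss can be checked by hand with NumPy; all numbers below are made up for illustration. The terminal transition's target is just its reward, because `(1 - dones)` switches off bootstrapping:

```python
import numpy as np

gamma = 0.99
rewards = np.array([1.0, -1.0, 0.5])
dones = np.array([0.0, 1.0, 0.0])        # 1.0 marks a terminal transition
next_q_max = np.array([2.0, 5.0, -0.5])  # max_a' Q_target(s', a')
q_values = np.array([2.5, -0.5, 0.0])    # Q_policy(s, a)
weights = np.array([1.0, 0.5, 0.25])     # normalized IS weights

targets = rewards + gamma * next_q_max * (1.0 - dones)  # [2.98, -1.0, 0.005]
td_errors = targets - q_values                          # [0.48, -0.5, 0.005]
loss = np.mean(weights * td_errors ** 2)
```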
Saving and Loading the Model
```python
    def save(self, filename):
        torch.save(self.policy_net.state_dict(), filename)

    def load(self, filename, map_location=None):
        self.policy_net.load_state_dict(torch.load(filename, map_location=map_location))
        self.update_target()
```
- `save`: Saves the policy network's state dictionary to a file.
- `load`: Loads the policy network's state dictionary from a file and updates the target network.
To summarize: the `DQNAgent` class combines a DQN model, a prioritized replay buffer, and an epsilon-greedy policy to interact with an environment, select actions, and learn from experiences. It includes methods for selecting actions, training the network, and saving/loading the model. The agent uses a target network for stability and updates priorities in the replay buffer to prioritize important transitions during training.
Initialization and Hyperparameters
```python
import gym

env = gym.make("LunarLander-v2")
buffer = PrioritizedReplayBuffer(10000)
agent = DQNAgent(env, buffer)
num_episodes = 2000
target_update_freq = 10
rewards = []
path = 'per.pth'

try:
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    # load() updates the agent in place and returns None, so agent must not be reassigned
    agent.load(path, map_location=map_location)
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")
```
The Training Loop
This code snippet implements the training loop for a Deep Q-Network (DQN) agent in a reinforcement learning environment. Here’s a detailed explanation of each part.
```python
# Written for the classic gym (< 0.26) API: reset() returns the state, step() returns 4 values
for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0
    for t in range(1000):
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        agent.train_step()
        if done:
            break
    rewards.append(total_reward)
    print(f"Episode {episode}, Total Reward: {total_reward}")
    if episode % target_update_freq == 0:
        agent.update_target()
        agent.save(path)
    # Stop once the last 5 episodes average more than 200 reward
    if episode > 100 and sum(rewards[-5:]) > 1000:
        agent.save(path)
        print(f'\nEnvironment solved in {episode} episodes!')
        break
env.close()
```
Episode Loop:
- The loop runs for a specified number of episodes (`num_episodes`).
- At the start of each episode, the environment is reset to obtain the initial state.
- The total reward for the episode is initialized to zero.

Step Loop:
- The inner loop runs for a maximum of 1000 steps (or until the episode terminates).
- Action Selection: The agent selects an action for the current state using the `select_action` method.
- Environment Step: The selected action is taken in the environment using `env.step(action)`, which returns the next state, the reward, the done flag, and additional info.
- Store Transition: The transition (state, action, reward, next state, done) is added to the replay buffer.
- Update State and Reward: The current state is updated to the next state, and the reward received is added to the total reward.
- Training Step: The agent performs a training step using the `train_step` method.
- Check for Episode Termination: If the episode is done (i.e., the environment returned `done=True`), the loop breaks.

Track Rewards:
- After the inner loop completes, the total reward for the episode is appended to the `rewards` list.
- A message is printed showing the episode number and the total reward for that episode.

Target Network Update and Model Saving:
- Every `target_update_freq` episodes, the target network is updated using `agent.update_target()`.
- The model is also saved to the specified path using `agent.save(path)`.

Early Stopping:
- After 100 episodes, if the rewards of the last 5 episodes sum to more than 1000 (an average above 200, the score at which LunarLander is usually considered solved), the model is saved, a message reports that the environment is solved, and the training loop terminates.

Environment Cleanup:
- After the training loop completes, the environment is closed using `env.close()`.
To summarize: this training loop repeatedly interacts with the environment, collects experiences, trains the DQN agent, and periodically updates the target network and saves the model. It also stops early once the agent consistently achieves high rewards, indicating that it has learned an effective policy.
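The loop above uses the classic Gym API. From Gym 0.26 onwards, and in Gymnasium, `reset()` returns `(observation, info)` and `step()` returns five values with separate `terminated` and `truncated` flags. Below is a minimal sketch of adapter functions that accept either form; `unpack_reset` and `unpack_step` are hypothetical helpers, not part of Gym:

```python
def unpack_reset(result):
    """Return the observation from env.reset() under either Gym API."""
    # gym>=0.26 / Gymnasium return (obs, info); older gym returns obs alone.
    # Assumes the observation itself is not a 2-tuple (true for LunarLander's arrays).
    if isinstance(result, tuple) and len(result) == 2:
        return result[0]
    return result


def unpack_step(result):
    """Normalize env.step() output to (obs, reward, terminated, truncated)."""
    if len(result) == 5:  # gym>=0.26 / Gymnasium
        obs, reward, terminated, truncated, _info = result
    else:  # classic gym: (obs, reward, done, info); time limits also set done
        obs, reward, terminated, _info = result
        truncated = False
    return obs, reward, terminated, truncated
```

With the newer API, storing `terminated` as the `done` flag in the buffer keeps bootstrapping from the next state when an episode is merely cut off by the time limit, while the episode itself should end on `terminated or truncated`.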
Plotting the rewards
This code effectively visualizes the moving average of rewards over episodes during the training of a reinforcement learning agent. The moving average helps smooth out fluctuations and highlights the overall trend. The shaded area around the moving average provides a sense of the variability in the rewards.
```python
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Episode rewards collected by the training loop
data = rewards

# Calculate the moving average and moving standard deviation
window_size = 10
moving_avg = pd.Series(data).rolling(window=window_size).mean()
moving_std = pd.Series(data).rolling(window=window_size).std()

# Plotting
plt.figure(figsize=(10, 6))

# Plot the moving average line
sns.lineplot(data=moving_avg, color='red')

# Shade one moving standard deviation on either side of the moving average
plt.fill_between(range(len(moving_avg)),
                 moving_avg - moving_std,
                 moving_avg + moving_std,
                 color='blue', alpha=0.2)

plt.xlabel('Episodes')
plt.ylabel('Rewards')
plt.title('Moving Average of Rewards')
plt.grid(True)

# Adjust layout to prevent overlapping elements
plt.tight_layout()

# Save the plot as a PNG file
plt.savefig('Episode_rewards.png')

# Show the plot
plt.show()
```