PyTorch
SARSA in the Wind

SARSA in the Wind

We will use the SARSA algorithm to find an optimal policy that lets our agent navigate the Windy Gridworld.

SARSA

State–action–reward–state–action (SARSA) is an algorithm for learning a Markov decision process policy, used in the reinforcement learning area of machine learning.

SARSA focuses on state-action values.

It updates the Q-function based on the following equation:

$$Q(s,a) \leftarrow Q(s,a) + \alpha \left[ r + \gamma \, Q(s',a') - Q(s,a) \right]$$

Here, s' is the state that results from taking action a in state s, and a' is the next action, chosen in s' by the same policy that selected a;

r is the associated reward;

α is the learning rate

and γ is the discount factor.

Windy GridWorld

Windy Gridworld is a grid problem on a 7 × 10 board (7 rows, 10 columns) whose tiles are numbered 0 to 69 row by row, starting from the top-left corner. It is displayed as follows:

At each step, the agent moves up, right, down, or left. Tile 30 is the agent's starting point, and tile 37 is the goal: an episode ends as soon as it is reached. Every step the agent takes incurs a reward of -1.

The complexity in this environment is that there is an extra wind force in columns 4 to 9.
When moving from a tile in one of those columns, the agent is pushed upward by the wind. The
wind force is 2 in the seventh and eighth columns and 1 in the fourth, fifth, sixth, and ninth
columns. For example, if the agent tries to move right from state 43, it will land in state 34;
if it tries to move left from state 48, it will land in state 37; if it tries to move up from
state 67, it will land in state 37, as it receives an additional 2-unit upward push; and if it
tries to move down from state 27, it will land in state 17, as the 2-unit upward push outweighs
the 1-unit downward move.

Developing the Windy Gridworld environment:

#Import the basic libraries
import numpy as np
import sys
from gym.envs.toy_text import discrete
# Action indices: each one is a key of the per-state transition table P[s]
UP, RIGHT, DOWN, LEFT = range(4)
# The Windy Gridworld environment class
class WindyGridworldEnv(discrete.DiscreteEnv):
  """
  Windy Gridworld on a 7x10 board, built on gym's tabular DiscreteEnv.

  States are the board tiles flattened row by row with np.ravel_multi_index
  over self.shape, so tile 10 * row + col sits at (row, col) and row 0 is the
  top row. Every move is deterministic, costs -1.0, and is pushed upward by the
  wind of the column the agent moves *from*. Episodes start at (3, 0) and end
  on reaching self.goal = (3, 7).
  """

  def __init__(self):
    self.shape = (7, 10)
    nS = self.shape[0] * self.shape[1]
    nA = 4
    # Upward wind strength per column (0-indexed columns): 1 in columns
    # 3, 4, 5 and 8, 2 in columns 6 and 7, 0 elsewhere.
    winds = np.zeros(self.shape)
    winds[:, [3, 4, 5, 8]] = 1
    winds[:, [6, 7]] = 2
    self.goal = (3, 7)

    # Transition table in the DiscreteEnv format:
    # P[state][action] -> [(probability, next_state, reward, done)]
    moves = {UP: [-1, 0], RIGHT: [0, 1], DOWN: [1, 0], LEFT: [0, -1]}
    P = {}
    for s in range(nS):
      position = np.unravel_index(s, self.shape)
      P[s] = {a: self._calculate_transition_prob(position, delta, winds)
              for a, delta in moves.items()}

    # Initial state distribution: we always start in state (3, 0)
    isd = np.zeros(nS)
    isd[np.ravel_multi_index((3, 0), self.shape)] = 1.0
    super().__init__(nS, nA, P, isd)

  def _calculate_transition_prob(self, current, delta, winds):
    """
    Determine the outcome for an action.
    The transition probability is always 1.0.

    @param current: (row, col), current position on the grid
    @param delta: change in position requested by the action, e.g. [-1, 0]
    @param winds: array of shape self.shape with the upward wind per tile
    @return: [(1.0, new_state, -1.0, is_done)]
    """
    # The wind of the *current* tile pushes the agent up (towards row 0).
    new_position = (np.array(current) + np.array(delta)
                    + np.array([-1, 0]) * winds[tuple(current)])
    new_position = self._limit_coordinates(new_position).astype(int)
    new_state = np.ravel_multi_index(tuple(new_position), self.shape)
    is_done = tuple(new_position) == self.goal
    return [(1.0, new_state, -1.0, is_done)]

  def _limit_coordinates(self, coord):
    """
    Clamp a (row, col) coordinate in place so it stays on the board.

    @param coord: mutable (row, col) sequence, e.g. a numpy array
    @return: the same coord object, clamped
    """
    coord[0] = min(max(coord[0], 0), self.shape[0] - 1)
    coord[1] = min(max(coord[1], 0), self.shape[1] - 1)
    return coord

  def render(self):
    """
    Print the board to stdout, one line per row, followed by a blank line.

    'x' marks the agent's current state self.s, 'T' marks the goal tile and
    'o' marks every other tile. self.nS and self.s are attributes provided by
    the DiscreteEnv base class (not set in this class).
    """
    outfile = sys.stdout
    for s in range(self.nS):
      position = np.unravel_index(s, self.shape)
      if self.s == s:
        output = 'x'
      elif position == self.goal:
        output = "T"
      else:
        output = "o"
      # The last column closes the board row.
      if position[1] == self.shape[1] - 1:
        output += "\n"
      outfile.write(output)

    outfile.write("\n")


Instantiating a Windy Gridworld object and playing with it

# Build the environment and show the start state (tile 30).
env = WindyGridworldEnv()
env.reset()
env.render()
# One step UP from the start (lands on tile 20); print what step() returns.
print(env.step(UP))
env.render()

# Place the agent on a few windy tiles and take one action from each, rendering
# before and after. By the transition rules these land on:
# 43 -RIGHT-> 34, 48 -LEFT-> 37 (the goal), 47 -LEFT-> 26.
for start_state, move in ((43, RIGHT), (48, LEFT), (47, LEFT)):
  env.s = start_state
  env.render()
  env.step(move)
  env.render()

The SARSA Function

import torch
def gen_epsilon_greedy_policy(n_action, epsilon):
  """
  Build an epsilon-greedy action selector.

  Every action gets probability epsilon / n_action. The action with the
  highest value in Q[state] gets an extra 1 - epsilon. One action is then
  sampled from that distribution.

  @param n_action: number of discrete actions
  @param epsilon: exploration rate
  @return: function (state, Q) -> sampled action index (int), where Q[state]
      is a tensor of action values
  """
  def policy_function(state, Q):
    greedy_action = int(torch.argmax(Q[state]))
    action_probs = torch.ones(n_action) * epsilon / n_action
    action_probs[greedy_action] += 1.0 - epsilon
    sampled = torch.multinomial(action_probs, 1)
    return sampled.item()
  return policy_function
# Number of training episodes, plus per-episode step counts and total rewards
# (indexed by episode number) that the training loop fills in.
n_episode = 500
length_episode = [0 for _ in range(n_episode)]
total_reward_episode = [0 for _ in range(n_episode)]
from collections import defaultdict
def sarsa(env, gamma, n_episode, alpha, behavior_policy=None,
          episode_lengths=None, episode_rewards=None):
  """
  Obtain the optimal policy with the on-policy SARSA algorithm.

  Each transition (s, a, r, s', a') performs
  Q(s, a) += alpha * (r + gamma * Q(s', a') - Q(s, a)),
  with the bootstrap term dropped when s' is terminal.

  @param env: Gym environment with a discrete action space, using the classic
      API (reset() returns the state, step() returns
      (state, reward, done, info))
  @param gamma: discount factor
  @param n_episode: number of episodes
  @param alpha: learning rate (step size)
  @param behavior_policy: callable (state, Q) -> action, used both to act and
      to choose a'; defaults to the module-level epsilon_greedy_policy
  @param episode_lengths: list receiving the step count of each episode;
      defaults to the module-level length_episode
  @param episode_rewards: list receiving the total reward of each episode;
      defaults to the module-level total_reward_episode
  @return: (Q, policy) - Q maps state -> tensor of action values, policy maps
      state -> greedy action index
  """
  # Fall back to the notebook-level globals so existing calls keep working.
  if behavior_policy is None:
    behavior_policy = epsilon_greedy_policy
  if episode_lengths is None:
    episode_lengths = length_episode
  if episode_rewards is None:
    episode_rewards = total_reward_episode
  # Grow the stats lists in place (keeping their identity) when the run is
  # longer than they are, instead of raising IndexError mid-training.
  for stats in (episode_lengths, episode_rewards):
    if len(stats) < n_episode:
      stats.extend([0] * (n_episode - len(stats)))

  n_action = env.action_space.n
  Q = defaultdict(lambda: torch.zeros(n_action))
  for episode in range(n_episode):
    # Start each episode's stats from zero so re-running does not accumulate.
    episode_lengths[episode] = 0
    episode_rewards[episode] = 0
    state = env.reset()
    is_done = False
    action = behavior_policy(state, Q)
    while not is_done:
      next_state, reward, is_done, _ = env.step(action)
      next_action = behavior_policy(next_state, Q)
      # Q(s', a') is read even when s' is terminal so that s' still gets an
      # entry in Q and in the returned policy; its value is then ignored.
      next_q = Q[next_state][next_action]
      target = reward if is_done else reward + gamma * next_q
      td_delta = target - Q[state][action]
      Q[state][action] += alpha * td_delta
      episode_lengths[episode] += 1
      episode_rewards[episode] += reward
      state, action = next_state, next_action

  policy = {s: torch.argmax(q_values).item() for s, q_values in Q.items()}
  return Q, policy

Finding The Optimal Policy with SARSA

# Hyperparameters: undiscounted returns (gamma = 1), step size alpha = 0.4,
# and 10% exploration for the epsilon-greedy behaviour policy.
gamma, alpha, epsilon = 1, 0.4, 0.1
# sarsa() looks up the module-level name epsilon_greedy_policy to choose
# actions, so it must be bound before training starts.
epsilon_greedy_policy = gen_epsilon_greedy_policy(env.action_space.n, epsilon)
optimal_Q, optimal_policy = sarsa(env, gamma, n_episode, alpha)
print('The optimal policy: \n', optimal_policy)

The optimal policy: {30: 1, 20: 1, 10: 2, 21: 1, 11: 1, 12: 1, 2: 0, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 2, 19: 2, 29: 2, 18: 2, 13: 1, 1: 1, 22: 1, 23: 1, 0: 2, 39: 2, 28: 2, 32: 1, 31: 1, 33: 1, 40: 2, 14: 1, 41: 1, 42: 1, 43: 1, 38: 2, 50: 2, 51: 1, 52: 1, 53: 3, 24: 1, 34: 1, 61: 1, 62: 1, 44: 1, 15: 1, 17: 0, 49: 2, 59: 3, 48: 3, 69: 3, 68: 2, 27: 0, 60: 1, 63: 1, 25: 1, 37: 0, 54: 1, 45: 1, 16: 2, 35: 1, 58: 3, 47: 2, 57: 2, 36: 2, 26: 2}

Plotting Episode Length and Reward

import matplotlib.pyplot as plt
def _plot_per_episode(values, title, ylabel):
  """Draw one per-episode curve in its own 10x7 figure and show it."""
  plt.figure(figsize=(10, 7))
  plt.plot(values)
  plt.title(title)
  plt.xlabel('Episode')
  plt.ylabel(ylabel)
  plt.show()

_plot_per_episode(length_episode, 'Episode length over time', 'Length')
_plot_per_episode(total_reward_episode, 'Episode reward over time', 'Total reward')

The agent navigating the Windy Gridworld with the policy learned by SARSA