PyTorch
Master Snake Game AI with PPO: Step-by-Step Guide (Part I)

This is a two-part tutorial. In the first part, we’ll create a Snake game environment; in the second part, we’ll develop an AI that uses Proximal Policy Optimization (PPO) to master the gameplay.

Creating a Snake game environment

The Snake game is a classic, simple arcade game that has captivated players for decades. The objective is straightforward: control a snake as it navigates a grid, consuming fruit to grow longer while avoiding collisions with itself and the grid boundaries. The challenge lies in maneuvering the increasingly long snake: each piece of fruit adds to its body, making collisions harder to avoid.

Game Rules:

Objective:

The primary goal is to guide the snake to the fruit, which appears on a random free cell of the grid (one fruit at a time). Each time the snake eats a fruit, it grows longer by one unit and a new fruit appears.

Movement:

The snake can move in four directions: up, down, left, or right. The player controls the snake’s direction, but the snake cannot reverse onto itself (e.g., if it is moving right, it cannot immediately move left; such a request is ignored and the snake keeps moving right).
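The no-reversal rule can be expressed as a small lookup. The sketch below uses a hypothetical next_direction() helper; it assumes the action numbering used later by step() (0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT) and, like step(), keeps the current heading when the requested move would reverse it.

```python
# Action indices as used by step(): 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT
ACTIONS = ['UP', 'DOWN', 'LEFT', 'RIGHT']
OPPOSITE = {'UP': 'DOWN', 'DOWN': 'UP', 'LEFT': 'RIGHT', 'RIGHT': 'LEFT'}


def next_direction(current, action):
    """Return the new heading; a request to reverse is ignored (hypothetical helper)."""
    requested = ACTIONS[action]
    if requested == OPPOSITE[current]:
        return current  # Reversing is not allowed, so keep the current heading
    return requested


print(next_direction('RIGHT', 0))  # UP
print(next_direction('RIGHT', 2))  # RIGHT (the LEFT request is ignored)
```

A lookup table like OPPOSITE replaces four near-identical if/elif branches while producing the same result as the checks in step().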

Grid:

The snake moves within a grid, which acts as the gameboard.

Scoring System

Fruit Consumption:

The snake earns points by consuming fruit. Each fruit typically awards a fixed number of points; in this environment, eating a fruit awards 10 points.

Penalty for Inactivity:

To encourage efficient gameplay, the snake incurs a small penalty of 0.01 points on every step. This penalty motivates the player to seek out fruit rather than move around aimlessly.

Self Collision Penalty:

If the snake collides with itself, the game ends and a significant penalty of -5 points is applied.

Boundary Collision Penalty:

Similarly, if the snake collides with the grid’s boundary, the game ends with a -5 point penalty.

Additionally, if the player takes too many steps without eating fruit (in this case, more than 100 steps), the game automatically ends to prevent stalling, and an additional penalty of 5 points is subtracted from that step’s reward.
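The rules above can be folded into a single per-step reward. The following is an illustrative pure-Python sketch: step_reward() and the constant names are hypothetical, and it assumes the 0.01 step penalty applies on every step, including the one where a fruit is eaten.

```python
STEP_PENALTY = 0.01
FRUIT_REWARD = 10
COLLISION_PENALTY = -5
STALL_PENALTY = 5
MAX_STEPS_WITHOUT_FRUIT = 100


def step_reward(ate_fruit, collided, steps_without_fruit):
    """Return (reward, done) for one step under the rules above (illustrative sketch).

    steps_without_fruit is the counter value after this step has been taken.
    """
    reward = (FRUIT_REWARD if ate_fruit else 0) - STEP_PENALTY
    done = collided
    if collided:
        reward = COLLISION_PENALTY  # Flat penalty; the game ends
    if steps_without_fruit > MAX_STEPS_WITHOUT_FRUIT:
        reward -= STALL_PENALTY  # Extra penalty for stalling; the game ends
        done = True
    return reward, done


print(step_reward(ate_fruit=False, collided=True, steps_without_fruit=3))  # (-5, True)
```

This reward is what the agent learns from. The score shown on screen only accumulates the +10 and -5 points, so it does not include the 0.01 step penalty.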

Let’s dive into the code

Installing the required libraries

pip install gym pygame numpy

Step 1: Initialize the Environment

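We begin by setting up the basic structure of a Gym environment: a class that inherits from gym.Env and defines the methods __init__(), reset(), step(), and render(). The code follows the classic Gym API, in which reset() returns only the observation and step() returns a four-element tuple (observation, reward, done, info).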

import gym
from gym import spaces
import numpy as np
import pygame
import random

Defining Constants: We’ll define the colors used to draw the game. The window size and grid dimensions are set later, in the class constructor.

BROWN = (139, 69, 19)  # Background color
SNAKE_COLOR = (0, 255, 0)  # Snake color
FRUIT_COLOR = (255, 0, 0)  # Fruit color
HEAD_COLOR = (0, 0, 255)  # Head of the snake

Step 2: Create the SnakeEnvWithPenalty Class

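This class manages everything related to the environment, including snake movement, fruit generation, rewards, and rendering. The methods shown in the following steps all belong inside this class, so indent them one level when assembling the full file.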

class SnakeEnvWithPenalty(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, width=500, height=500, rows=7, cols=7):
        super(SnakeEnvWithPenalty, self).__init__()

        # Window size and grid dimensions
        self.WIDTH = width
        self.HEIGHT = height
        self.ROWS = rows
        self.COLS = cols
        self.CELL_SIZE = self.WIDTH // self.COLS  # Size of each cell

        # Gym spaces
        self.action_space = spaces.Discrete(4)  # 4 directions: UP, DOWN, LEFT, RIGHT
        self.observation_space = spaces.Box(low=0, high=7, shape=(self.ROWS, self.COLS), dtype=np.uint8)

        # Initialize Pygame (pygame.init() is required for the font used in render())
        pygame.init()
        self.window = pygame.display.set_mode((self.WIDTH, self.HEIGHT))
        pygame.display.set_caption("Snake Game with Penalty")
        self.clock = pygame.time.Clock()

        self.reset()

Step 3: Reset the Environment

The reset() function places the snake at a random starting position, spawns a fruit on a free cell, resets the score and the step counter, and returns the initial observation.

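def reset(self):
    self.snake_positions = self.generate_snake()
    self.fruit_position = self.generate_fruit(self.snake_positions)
    # Face the snake away from its second segment so the first move can't hit its own body
    (head_row, head_col), (neck_row, neck_col) = self.snake_positions[:2]
    if head_row == neck_row:
        self.direction = 'LEFT' if head_col < neck_col else 'RIGHT'
    else:
        self.direction = 'UP' if head_row < neck_row else 'DOWN'
    self.score = 0
    self.steps_without_fruit = 0  # Counter for steps without eating fruit

    return self._get_observation()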

Generate Snake: The snake starts with two segments at a random position, oriented either horizontally or vertically. The first element of the list is always the head.

def generate_snake(self):
    is_horizontal = random.choice([True, False])
    if is_horizontal:
        row = random.randint(0, self.ROWS - 1)
        col_start = random.randint(0, self.COLS - 2)
        snake_positions = [(row, col_start), (row, col_start + 1)]
    else:
        col = random.randint(0, self.COLS - 1)
        row_start = random.randint(0, self.ROWS - 2)
        snake_positions = [(row_start, col), (row_start + 1, col)]
    return snake_positions
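In generate_snake(), the head always sits to the left of (horizontal) or above (vertical) the second segment, so the snake’s natural heading at spawn is LEFT or UP. A horizontal snake that starts with direction 'RIGHT' would therefore run straight into its own second segment on its first move. The sketch below checks the orientation over many random spawns, using a standalone copy of the spawn logic and a hypothetical heading_from_body() helper.

```python
import random


def generate_snake(rows, cols, rng):
    """Standalone copy of the environment's spawn logic (head first)."""
    if rng.choice([True, False]):
        row = rng.randint(0, rows - 1)
        col_start = rng.randint(0, cols - 2)
        return [(row, col_start), (row, col_start + 1)]
    col = rng.randint(0, cols - 1)
    row_start = rng.randint(0, rows - 2)
    return [(row_start, col), (row_start + 1, col)]


def heading_from_body(snake):
    """Direction that moves the head away from the second segment (hypothetical helper)."""
    (head_row, head_col), (neck_row, neck_col) = snake[0], snake[1]
    if head_row == neck_row:
        return 'LEFT' if head_col < neck_col else 'RIGHT'
    return 'UP' if head_row < neck_row else 'DOWN'


rng = random.Random(0)
for _ in range(1000):
    snake = generate_snake(7, 7, rng)
    # Both segments are inside the 7x7 grid...
    assert all(0 <= r < 7 and 0 <= c < 7 for r, c in snake)
    # ...and the head always faces LEFT or UP at spawn
    assert heading_from_body(snake) in ('LEFT', 'UP')
```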

Generate Fruit: The fruit is placed randomly on the grid but not where the snake currently is.

def generate_fruit(self, snake_positions):
    while True:
        fruit_position = (random.randint(0, self.ROWS - 1), random.randint(0, self.COLS - 1))
        if fruit_position not in snake_positions:
            return fruit_position
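The rejection-sampling loop above retries until it hits a free cell. It slows down as the snake grows and never returns once the snake fills all 49 cells of the default 7x7 grid. A hypothetical alternative, sketched below, samples directly from the free cells and returns None when there are none. How the caller should handle None (for example, ending the episode as a win) is an assumption left open here.

```python
import random


def generate_fruit_from_free_cells(rows, cols, snake_positions, rng=random):
    """Pick a fruit cell uniformly among free cells; return None if the board is full.

    Hypothetical alternative to the rejection-sampling generate_fruit().
    """
    occupied = set(snake_positions)
    free = [(r, c) for r in range(rows) for c in range(cols) if (r, c) not in occupied]
    if not free:
        return None  # The snake fills the whole grid, so there is nowhere to put fruit
    return rng.choice(free)


print(generate_fruit_from_free_cells(2, 2, [(0, 0), (0, 1), (1, 0)]))  # (1, 1)
```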

Step 4: Implement the step() Function

This function will handle the logic for each action taken by the snake, updating its position, checking for collisions, and calculating the reward.

def step(self, action):
    # Translate the action into a direction; a request to reverse is ignored
    if action == 0 and self.direction != 'DOWN':
        self.direction = 'UP'
    elif action == 1 and self.direction != 'UP':
        self.direction = 'DOWN'
    elif action == 2 and self.direction != 'RIGHT':
        self.direction = 'LEFT'
    elif action == 3 and self.direction != 'LEFT':
        self.direction = 'RIGHT'

    # Move the snake (move_snake() resets steps_without_fruit when a fruit is eaten)
    self.snake_positions, self.fruit_position, points, game_over, self_collision = self.move_snake(self.snake_positions, self.direction, self.fruit_position)
    self.score += points

    reward = points - 0.01  # Small penalty for every step
    if game_over:
        reward = -5  # Flat penalty for a boundary or self-collision
    done = game_over

    observation = self._get_observation()

    # Additional penalty if the snake takes too many steps without eating fruit
    self.steps_without_fruit += 1
    if self.steps_without_fruit > 100:
        reward -= 5
        done = True

    return observation, reward, done, {}
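Because steps_without_fruit starts at 0 after reset() and is incremented after every move, the stall rule ends an episode in which no fruit is eaten on its 101st step. The hypothetical steps_until_stall_end() below replays just that counter logic to make the off-by-one explicit.

```python
def steps_until_stall_end(limit=100):
    """Count steps until the stall rule ends an episode in which no fruit is eaten."""
    steps_without_fruit = 0  # Value right after reset()
    steps = 0
    while True:
        steps += 1
        steps_without_fruit += 1  # step() increments the counter after moving
        if steps_without_fruit > limit:
            return steps


print(steps_until_stall_end())  # 101
```

After a fruit is eaten, move_snake() resets the counter to 0 and step() immediately increments it to 1, so the episode ends on the 100th consecutive step after the one that ate the fruit.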

Step 5: Render the Game

The render() method draws the snake, fruit, and grid on the screen using Pygame.

def render(self, mode='human', close=False):
    pygame.event.pump()  # Process window events so the window stays responsive
    self.window.fill((255, 255, 255))
    self.draw_grid()
    self.draw_snake(self.snake_positions)
    self.draw_fruit(self.fruit_position)

    # Display score
    font = pygame.font.Font(None, 36)
    text = font.render(f"Score: {self.score}", True, (0, 0, 0))
    self.window.blit(text, (10, 10))

    pygame.display.flip()
    self.clock.tick(15)  # Set frame rate

The draw_grid(), draw_snake(), and draw_fruit() functions handle the Pygame drawing logic:

def draw_grid(self):
    for row in range(self.ROWS):
        for col in range(self.COLS):
            rect = pygame.Rect(col * self.CELL_SIZE, row * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
            pygame.draw.rect(self.window, BROWN, rect)
            pygame.draw.rect(self.window, (0, 0, 0), rect, 1)

def draw_snake(self, snake_positions):
    for i, pos in enumerate(snake_positions):
        rect = pygame.Rect(pos[1] * self.CELL_SIZE, pos[0] * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
        pygame.draw.rect(self.window, SNAKE_COLOR, rect)
        if i == 0:  # Head of the snake
            center = (pos[1] * self.CELL_SIZE + self.CELL_SIZE // 2, pos[0] * self.CELL_SIZE + self.CELL_SIZE // 2)
            pygame.draw.circle(self.window, HEAD_COLOR, center, self.CELL_SIZE // 4)

def draw_fruit(self, fruit_position):
    rect = pygame.Rect(fruit_position[1] * self.CELL_SIZE, fruit_position[0] * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
    pygame.draw.rect(self.window, FRUIT_COLOR, rect)
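Pygame’s Rect takes (left, top, width, height), so a cell’s x coordinate comes from its column and its y coordinate from its row. The worked example below uses the default 500-pixel window and 7 columns. cell_rect() and head_marker() are hypothetical helpers that mirror the arithmetic in draw_snake() and draw_fruit().

```python
def cell_rect(row, col, cell_size):
    """(x, y, width, height) of a grid cell; x comes from the column, y from the row."""
    return (col * cell_size, row * cell_size, cell_size, cell_size)


def head_marker(row, col, cell_size):
    """Center and radius of the circle drawn on the snake's head."""
    center = (col * cell_size + cell_size // 2, row * cell_size + cell_size // 2)
    return center, cell_size // 4


cell_size = 500 // 7                 # 71 pixels with the default window and grid
print(cell_rect(2, 3, cell_size))    # (213, 142, 71, 71)
print(head_marker(2, 3, cell_size))  # ((248, 177), 17)
print(7 * cell_size)                 # 497
```

Because CELL_SIZE uses integer division, the grid covers 497 x 497 pixels, leaving a 3-pixel strip along the right and bottom edges that keeps the white fill from render().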

Step 6: Moving the Snake

In the move_snake() function, the snake’s head position is updated based on the direction, and we check for collisions with the walls or the snake’s own body.

def move_snake(self, snake_positions, direction, fruit_position):
    head_row, head_col = snake_positions[0]  # Positions are (row, col)
    if direction == 'UP':
        new_head = (head_row - 1, head_col)
    elif direction == 'DOWN':
        new_head = (head_row + 1, head_col)
    elif direction == 'LEFT':
        new_head = (head_row, head_col - 1)
    elif direction == 'RIGHT':
        new_head = (head_row, head_col + 1)

    self_collision = False
    if not (0 <= new_head[0] < self.ROWS and 0 <= new_head[1] < self.COLS):
        return snake_positions, fruit_position, -5, True, self_collision  # Boundary collision

    if new_head in snake_positions:
        self_collision = True
        return snake_positions, fruit_position, -5, True, self_collision  # Self-collision

    if new_head == fruit_position:
        grown_snake = [new_head] + snake_positions  # Keep the tail: the snake grows
        # Spawn the next fruit outside the grown snake, including the new head
        fruit_position = self.generate_fruit(grown_snake)
        self.steps_without_fruit = 0  # Reset steps counter
        return grown_snake, fruit_position, 10, False, self_collision
    else:
        return [new_head] + snake_positions[:-1], fruit_position, 0, False, self_collision
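The if/elif chain in move_snake() can also be written as a lookup of row/column offsets. The sketch below shows the two list operations that make the snake slide or grow. DELTAS and advance() are hypothetical names, and collision checks are omitted.

```python
# Row/column offsets for each heading (rows grow downward, as in the Pygame window)
DELTAS = {'UP': (-1, 0), 'DOWN': (1, 0), 'LEFT': (0, -1), 'RIGHT': (0, 1)}


def advance(snake, direction, fruit):
    """Return the new body after one move, ignoring collisions (illustrative sketch)."""
    d_row, d_col = DELTAS[direction]
    head_row, head_col = snake[0]
    new_head = (head_row + d_row, head_col + d_col)
    if new_head == fruit:
        return [new_head] + snake       # Keep the tail: the snake grows by one
    return [new_head] + snake[:-1]      # Drop the tail: the snake slides forward


snake = [(3, 2), (3, 1)]
print(advance(snake, 'RIGHT', fruit=(0, 0)))  # [(3, 3), (3, 2)]
print(advance(snake, 'RIGHT', fruit=(3, 3)))  # [(3, 3), (3, 2), (3, 1)]
```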

Step 7: Getting the Observation

The _get_observation() function returns the game state as a ROWS x COLS grid in which empty cells are 0, the snake’s body is 1, its head is 7, and the fruit is 2.

def _get_observation(self):
    grid = np.full((self.ROWS, self.COLS), 0, dtype=np.uint8)

    for pos in self.snake_positions[1:]:
        grid[pos[0], pos[1]] = 1  # Snake body
    head_pos = self.snake_positions[0]
    grid[head_pos[0], head_pos[1]] = 7  # Snake head

    grid[self.fruit_position[0], self.fruit_position[1]] = 2  # Fruit

    return grid
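To make the encoding concrete, here is a dependency-free sketch that builds the same grid as _get_observation(), using nested lists instead of a NumPy array. observation_grid() is a hypothetical name, and the 4x4 board is chosen only for readability.

```python
EMPTY, BODY, FRUIT, HEAD = 0, 1, 2, 7


def observation_grid(rows, cols, snake, fruit):
    """Encode the board the same way as _get_observation()."""
    grid = [[EMPTY] * cols for _ in range(rows)]
    for r, c in snake[1:]:
        grid[r][c] = BODY
    head_r, head_c = snake[0]
    grid[head_r][head_c] = HEAD
    fruit_r, fruit_c = fruit
    grid[fruit_r][fruit_c] = FRUIT
    return grid


for row in observation_grid(4, 4, [(1, 2), (1, 1), (2, 1)], fruit=(3, 3)):
    print(row)
# [0, 0, 0, 0]
# [0, 1, 7, 0]
# [0, 1, 0, 0]
# [0, 0, 0, 2]
```

The codes are labels rather than magnitudes. Whether to scale or one-hot encode them before feeding them to a network is a design choice for the agent, not something this environment decides.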

Conclusion

You now have a fully functioning Snake game environment, built from scratch using Gym and Pygame. Because it exposes a Gym-style interface, it can be paired with reinforcement learning algorithms such as DQN, PPO, or A2C.
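Here is a minimal sketch of how an agent interacts with an environment that follows the classic Gym API used here. run_episode() and the CountdownEnv stand-in are hypothetical, chosen so the loop can run without Pygame.

```python
import random


def run_episode(env, policy, max_steps=1000):
    """Roll out one episode with the classic Gym API; returns (total_reward, steps)."""
    observation = env.reset()
    total_reward, steps = 0.0, 0
    done = False
    while not done and steps < max_steps:
        action = policy(observation)
        observation, reward, done, info = env.step(action)
        total_reward += reward
        steps += 1
    return total_reward, steps


class CountdownEnv:
    """Hypothetical stand-in env: every step costs 0.01 and the episode ends after 3 steps."""

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, -0.01, self.t >= 3, {}


total, steps = run_episode(CountdownEnv(), policy=lambda obs: random.randrange(4))
print(steps, round(total, 2))  # 3 -0.03
```

To watch a random agent play the Snake game, pass an instance of SnakeEnvWithPenalty as env, use lambda obs: env.action_space.sample() as the policy, and call env.render() after each step.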

In the next post, we will build a PPO agent that learns to play this Snake game.
