Master Snake Game AI with PPO: Step-by-Step Guide (Part I)
This is a two-part tutorial. In the first part, we’ll create a Snake game environment; in the second, we’ll develop an AI that uses Proximal Policy Optimization (PPO) to master the gameplay.
Creating a Snake game environment
The Snake game is a classic, simple arcade game that has captivated players for decades. The objective is straightforward: control a snake as it navigates a grid, consuming fruit to grow longer while avoiding collisions with itself and the grid boundaries. The challenge lies in maneuvering the increasingly long snake: each piece of fruit adds to its body, making collisions harder to avoid.
Game Rules:
Objective:
The primary goal is to guide the snake toward the fruit scattered across the grid. Each time the snake eats a fruit, it grows longer by one unit.
Movement:
The snake can move in four directions: up, down, left, or right. The player controls the snake’s direction, but the snake cannot reverse (e.g., if it is moving right, it cannot immediately move left).
Grid:
The snake moves within a grid, which acts as the gameboard.
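The no-reversal rule under Movement can be sketched as a small lookup. This is a minimal illustration rather than the environment's own code; the names OPPOSITE and next_direction are hypothetical and do not appear elsewhere in this tutorial.

```python
# Hypothetical helper illustrating the no-reversal rule.
OPPOSITE = {'UP': 'DOWN', 'DOWN': 'UP', 'LEFT': 'RIGHT', 'RIGHT': 'LEFT'}


def next_direction(current, requested):
    """Return the direction the snake will actually take."""
    # A request to turn straight back is ignored; the snake keeps going.
    if requested == OPPOSITE[current]:
        return current
    return requested


print(next_direction('RIGHT', 'LEFT'))  # reversal request is ignored
print(next_direction('RIGHT', 'UP'))    # ordinary turn is accepted
```

A lookup table keeps the rule in one place, so adding or renaming directions does not require touching a chain of if/elif branches.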
Scoring System
Fruit Consumption:
The snake earns points by consuming fruit. In this environment, each fruit awards 10 points.
Penalty for Inactivity:
To encourage efficient gameplay, the snake incurs a small penalty of -0.01 points for each step taken without eating a fruit. This penalty motivates the player to seek out fruit rather than wander aimlessly.
Self Collision Penalty:
If the snake collides with itself, the game ends and a significant penalty of -5 points is applied.
Boundary Collision Penalty:
Similarly, if the snake collides with the grid’s boundary, the game ends with a penalty of -5 points.
Additionally, if the player takes too many steps without eating fruit (in this case, more than 100 steps), the game ends automatically to prevent stalling, and a further penalty is subtracted from that step’s reward (1 point in the step() code below).
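To make the arithmetic concrete, here is a minimal sketch that adds up the rewards for a short episode using the values listed above. The helper name step_reward and the constant names are our own, and the stall penalty is left out for brevity. Note that the step() code later in this post combines these values slightly differently (for example, it also subtracts the step penalty on fruit steps), so treat this as an illustration of the rules, not a copy of the implementation.

```python
# Reward values taken from the scoring rules listed above.
FRUIT_REWARD = 10.0
STEP_PENALTY = -0.01
COLLISION_PENALTY = -5.0


def step_reward(ate_fruit=False, collided=False):
    """Reward for one step under the listed rules (hypothetical helper)."""
    if collided:
        return COLLISION_PENALTY
    if ate_fruit:
        return FRUIT_REWARD
    return STEP_PENALTY


# Three plain steps, one fruit, then a collision that ends the game.
events = [{}, {}, {}, {'ate_fruit': True}, {'collided': True}]
total = sum(step_reward(**event) for event in events)
print(round(total, 2))
```

Here the total is 3 × (-0.01) + 10 - 5 = 4.97.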
Let’s dive into the code
Install the required libraries
pip install gym pygame numpy
Step 1: Initialize the Environment
We begin by setting up the basic structure of a Gym environment. This includes inheriting from the gym.Env class and defining the necessary methods: __init__(), reset(), step(), and render().
import gym
from gym import spaces
import numpy as np
import pygame
import random
Defining Constants: We’ll define the colors used to draw the game. The window size is set later, in the class constructor.
BROWN = (139, 69, 19) # Background color
SNAKE_COLOR = (0, 255, 0) # Snake color
FRUIT_COLOR = (255, 0, 0) # Fruit color
HEAD_COLOR = (0, 0, 255) # Head of the snake
Step 2: Create the SnakeEnvWithPenalty Class
This class will manage everything related to the environment, including snake movement, fruit generation, rewards, and rendering.
class SnakeEnvWithPenalty(gym.Env):
    metadata = {'render.modes': ['human']}
    def __init__(self, width=500, height=500, rows=7, cols=7):
        super(SnakeEnvWithPenalty, self).__init__()
        # Window size and grid dimensions
        self.WIDTH = width
        self.HEIGHT = height
        self.ROWS = rows
        self.COLS = cols
        self.CELL_SIZE = self.WIDTH // self.COLS  # Size of each cell
        # Gym spaces
        self.action_space = spaces.Discrete(4)  # 4 directions: UP, DOWN, LEFT, RIGHT
        self.observation_space = spaces.Box(low=0, high=7, shape=(self.ROWS, self.COLS), dtype=np.uint8)
        # Initialize Pygame; pygame.init() also sets up the font module used in render()
        pygame.init()
        self.window = pygame.display.set_mode((self.WIDTH, self.HEIGHT))
        pygame.display.set_caption("Snake Game with Penalty")
        self.clock = pygame.time.Clock()
        self.reset()
Step 3: Reset the Environment
The reset() function resets the snake to a starting position and generates a fruit at a random position on the grid.
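    def reset(self):
        self.snake_positions = self.generate_snake()
        self.fruit_position = self.generate_fruit(self.snake_positions)
        # Start by heading away from the body: generate_snake() places the head
        # (index 0) to the left of, or above, the second segment. A fixed 'RIGHT'
        # would drive a horizontal snake straight into itself.
        head, neck = self.snake_positions[0], self.snake_positions[1]
        self.direction = 'LEFT' if head[0] == neck[0] else 'UP'
        self.score = 0
        self.steps_without_fruit = 0  # Counter for steps without eating fruit
        return self._get_observation()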
Generate Snake: We initialize a two-segment snake in either a horizontal or vertical orientation at a random position. The first element of the list is the head.
    def generate_snake(self):
        is_horizontal = random.choice([True, False])
        if is_horizontal:
            row = random.randint(0, self.ROWS - 1)
            col_start = random.randint(0, self.COLS - 2)
            snake_positions = [(row, col_start), (row, col_start + 1)]
        else:
            col = random.randint(0, self.COLS - 1)
            row_start = random.randint(0, self.ROWS - 2)
            snake_positions = [(row_start, col), (row_start + 1, col)]
        return snake_positions
Generate Fruit: The fruit is placed randomly on the grid but not where the snake currently is.
    def generate_fruit(self, snake_positions):
        while True:
            fruit_position = (random.randint(0, self.ROWS - 1), random.randint(0, self.COLS - 1))
            if fruit_position not in snake_positions:
                return fruit_position
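The retry loop above works well while most of the grid is empty, but it would never return if the snake ever covered every cell. As an alternative, here is a minimal sketch that lists the free cells first and picks one of them. The function pick_free_cell is hypothetical (it is not part of the environment class), and returning None for a full grid is an assumption about how you might want to signal that case.

```python
import random


def pick_free_cell(rows, cols, occupied, rng=random):
    """Return a random unoccupied (row, col), or None if the grid is full."""
    taken = set(occupied)
    free = [(r, c) for r in range(rows) for c in range(cols) if (r, c) not in taken]
    return rng.choice(free) if free else None


snake = [(0, 0), (0, 1)]
print(pick_free_cell(2, 2, snake))  # one of the two cells in the second row
```

On a 7x7 grid the extra list costs almost nothing, and the function always terminates.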
Step 4: Implement the step() Function
This function will handle the logic for each action taken by the snake, updating its position, checking for collisions, and calculating the reward.
    def step(self, action):
        # Translate action to direction
        if action == 0 and self.direction != 'DOWN':
            self.direction = 'UP'
        elif action == 1 and self.direction != 'UP':
            self.direction = 'DOWN'
        elif action == 2 and self.direction != 'RIGHT':
            self.direction = 'LEFT'
        elif action == 3 and self.direction != 'LEFT':
            self.direction = 'RIGHT'
        # Move the snake
        self.snake_positions, self.fruit_position, points, game_over, self_collision = self.move_snake(
            self.snake_positions, self.direction, self.fruit_position)
        self.score += points
        reward = points - 0.01  # Penalty for each step
        if self_collision:
            reward = -5  # Negative reward for self-collision
        done = game_over
        observation = self._get_observation()
        # Additional penalty if the snake takes too many steps without eating fruit
        self.steps_without_fruit += 1
        if self.steps_without_fruit > 100:
            reward -= 1
            done = True
        return observation, reward, done, {}
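To see how step() is meant to be driven, here is a minimal sketch of an episode loop. It assumes only the classic Gym interface used in this tutorial: reset() returns an observation and step() returns (observation, reward, done, info). The names run_episode and random_policy are our own, and max_steps is an extra safety cap that is not part of the environment.

```python
import random


def run_episode(env, policy, max_steps=1000):
    """Play one episode and return (total_reward, steps_taken)."""
    observation = env.reset()
    total_reward = 0.0
    for step_count in range(1, max_steps + 1):
        action = policy(observation)
        observation, reward, done, _info = env.step(action)
        total_reward += reward
        if done:
            return total_reward, step_count
    # The cap was reached before the environment reported done.
    return total_reward, max_steps


def random_policy(_observation):
    """Ignore the observation and pick one of the four actions at random."""
    return random.randrange(4)
```

With the environment from this post, a call such as run_episode(SnakeEnvWithPenalty(), random_policy) should play one random game; a random agent will usually crash within a few moves, which makes it a quick sanity check before any training.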
Step 5: Render the Game
The render() method draws the snake, fruit, and grid on the screen using Pygame.
    def render(self, mode='human', close=False):
        pygame.event.pump()  # Let Pygame process window events so the window stays responsive
        self.window.fill((255, 255, 255))
        self.draw_grid()
        self.draw_snake(self.snake_positions)
        self.draw_fruit(self.fruit_position)
        # Display score
        font = pygame.font.Font(None, 36)
        text = font.render(f"Score: {self.score}", True, (0, 0, 0))
        self.window.blit(text, (10, 10))
        pygame.display.flip()
        self.clock.tick(15)  # Set frame rate
The draw_grid(), draw_snake(), and draw_fruit() functions handle the Pygame drawing logic:
    def draw_grid(self):
        for row in range(self.ROWS):
            for col in range(self.COLS):
                rect = pygame.Rect(col * self.CELL_SIZE, row * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
                pygame.draw.rect(self.window, BROWN, rect)
                pygame.draw.rect(self.window, (0, 0, 0), rect, 1)
    def draw_snake(self, snake_positions):
        for i, pos in enumerate(snake_positions):
            rect = pygame.Rect(pos[1] * self.CELL_SIZE, pos[0] * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
            pygame.draw.rect(self.window, SNAKE_COLOR, rect)
            if i == 0:  # Head of the snake
                center = (pos[1] * self.CELL_SIZE + self.CELL_SIZE // 2, pos[0] * self.CELL_SIZE + self.CELL_SIZE // 2)
                pygame.draw.circle(self.window, HEAD_COLOR, center, self.CELL_SIZE // 4)
    def draw_fruit(self, fruit_position):
        rect = pygame.Rect(fruit_position[1] * self.CELL_SIZE, fruit_position[0] * self.CELL_SIZE, self.CELL_SIZE, self.CELL_SIZE)
        pygame.draw.rect(self.window, FRUIT_COLOR, rect)
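All three drawing functions use the same mapping from a grid cell to pixels: the column gives x and the row gives y. The sketch below isolates that mapping without Pygame. The helper cell_rect is hypothetical; its defaults mirror the constructor's width=500 and cols=7.

```python
def cell_rect(row, col, width=500, cols=7):
    """Return (x, y, w, h) in pixels for a grid cell (hypothetical helper)."""
    cell_size = width // cols  # same integer division as CELL_SIZE
    # Column maps to x and row maps to y, matching the drawing code.
    return (col * cell_size, row * cell_size, cell_size, cell_size)


print(cell_rect(2, 3))
print(500 - 7 * (500 // 7))  # pixels left uncovered along each axis
```

With the default sizes each cell is 71 pixels wide, so the grid covers 497 of the 500 pixels and a thin white strip remains at the right and bottom edges. Choosing a window size that is a multiple of the column count (for example 490) would remove it.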
Step 6: Moving the Snake
In the move_snake() function, the snake’s head position is updated based on the direction, and we check for collisions with the walls or the snake’s own body.
    def move_snake(self, snake_positions, direction, fruit_position):
        # Positions are (row, col) tuples; index 0 is the head
        head_row, head_col = snake_positions[0]
        if direction == 'UP':
            new_head = (head_row - 1, head_col)
        elif direction == 'DOWN':
            new_head = (head_row + 1, head_col)
        elif direction == 'LEFT':
            new_head = (head_row, head_col - 1)
        elif direction == 'RIGHT':
            new_head = (head_row, head_col + 1)
        self_collision = False
        if not (0 <= new_head[0] < self.ROWS and 0 <= new_head[1] < self.COLS):
            return snake_positions, fruit_position, -5, True, self_collision  # Boundary collision
        if new_head in snake_positions:
            self_collision = True
            return snake_positions, fruit_position, -5, True, self_collision  # Self-collision
        if new_head == fruit_position:
            grown_snake = [new_head] + snake_positions  # Grow snake
            # Place the next fruit clear of the new head as well as the body
            fruit_position = self.generate_fruit(grown_snake)
            self.steps_without_fruit = 0  # Reset steps counter
            return grown_snake, fruit_position, 10, False, self_collision
        else:
            return [new_head] + snake_positions[:-1], fruit_position, 0, False, self_collision
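The same movement logic can be tried outside the class. The sketch below is a table-driven variant and not the method above: DELTAS, advance, and the event labels ('wall', 'self', 'ate', 'moved') are hypothetical names, and fruit respawning and rewards are left out so that only the position update is shown.

```python
# (row, col) offsets for each direction; rows grow downward.
DELTAS = {'UP': (-1, 0), 'DOWN': (1, 0), 'LEFT': (0, -1), 'RIGHT': (0, 1)}


def advance(snake, direction, fruit, rows=7, cols=7):
    """Return (new_snake, event); event is 'wall', 'self', 'ate' or 'moved'."""
    d_row, d_col = DELTAS[direction]
    head_row, head_col = snake[0]
    new_head = (head_row + d_row, head_col + d_col)
    if not (0 <= new_head[0] < rows and 0 <= new_head[1] < cols):
        return snake, 'wall'
    if new_head in snake:
        return snake, 'self'
    if new_head == fruit:
        return [new_head] + snake, 'ate'     # keep the tail: the snake grows
    return [new_head] + snake[:-1], 'moved'  # drop the tail: same length


print(advance([(2, 2), (2, 3)], 'LEFT', fruit=(0, 0)))
print(advance([(2, 2), (2, 3)], 'LEFT', fruit=(2, 1)))
```

The only difference between a normal move and eating is whether the last segment is dropped, which is why growth needs no separate bookkeeping.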
Step 7: Getting the Observation
The _get_observation() function returns a grid representation of the environment, where the snake is represented by 1 for its body, 7 for the head, and 2 for the fruit.
    def _get_observation(self):
        grid = np.full((self.ROWS, self.COLS), 0, dtype=np.uint8)
        for pos in self.snake_positions[1:]:
            grid[pos[0], pos[1]] = 1  # Snake body
        head_pos = self.snake_positions[0]
        grid[head_pos[0], head_pos[1]] = 7  # Snake head
        grid[self.fruit_position[0], self.fruit_position[1]] = 2  # Fruit
        return grid
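To see what the agent will receive, here is a small sketch of the same encoding on a 3x4 grid. Plain nested lists stand in for the NumPy array so the snippet runs without dependencies; encode_grid and the constant names are hypothetical.

```python
# Cell codes used by the observation: empty, body, fruit, head.
EMPTY, BODY, FRUIT, HEAD = 0, 1, 2, 7


def encode_grid(rows, cols, snake, fruit):
    """Build the observation as nested lists; snake[0] is the head."""
    grid = [[EMPTY] * cols for _ in range(rows)]
    for r, c in snake[1:]:
        grid[r][c] = BODY
    head_r, head_c = snake[0]
    grid[head_r][head_c] = HEAD
    grid[fruit[0]][fruit[1]] = FRUIT
    return grid


for row in encode_grid(3, 4, snake=[(1, 1), (1, 2), (1, 3)], fruit=(0, 0)):
    print(row)
```

The head gets a value well apart from the body so that the direction of travel can be inferred from a single frame.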
Conclusion
You now have a working Snake game environment, built from scratch using Gym and Pygame. This environment can be used with reinforcement learning algorithms such as DQN, PPO, or A2C.
In the next post, we will build a PPO agent to play this Snake game.