With our game finished, let's move on to training an AI to play it! For this project, we'll use the NEAT (NeuroEvolution of Augmenting Topologies) algorithm, via the neat-python library.

Step 1: Setting up our file

  1. Make a copy of the main.py file called eval_genomes.py.
  2. Rather than having a main function, we’re going to use an eval_genomes function. It takes in a population of genomes and the configuration for the NEAT algorithm, then runs the game with those genomes. neat-python passes the population as a list of (genome_id, genome) pairs. We start with a simple change: replace the main function’s header with eval_genomes(genomes, config).
  3. For every genome, we’re going to need to create a corresponding bird and neural net (the phenotype of our genome).
```python
# Setup genomes: one bird and one neural net (the genome's phenotype) per genome
ge = []
nets = []
for genome_id, genome in genomes:
    Bird.birds.append(Bird(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2, Bird.COLORS[genome_id % 6]))
    ge.append(genome)
    nets.append(neat.nn.FeedForwardNetwork.create(genome, config))
    genome.fitness = 0  # every genome starts the run with zero fitness
```
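You can try the same bookkeeping without pygame or neat-python. In this minimal sketch, StubGenome, setup_individuals, the make_bird/make_net factories, and the six-entry COLORS list are hypothetical stand-ins. The real code uses Bird, Bird.COLORS, and neat.nn.FeedForwardNetwork.create. What the sketch keeps from the snippet is the invariant that birds[i], ge[i], and nets[i] always describe the same individual.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple

# Hypothetical palette standing in for Bird.COLORS (assumed to have six entries, hence "% 6")
COLORS = ["red", "orange", "yellow", "green", "blue", "purple"]


@dataclass
class StubGenome:
    """Hypothetical stand-in for a neat-python genome; only .fitness matters here."""
    key: int
    fitness: Optional[float] = None


def setup_individuals(
    genomes: List[Tuple[int, StubGenome]],
    make_bird: Callable[[int], Any],
    make_net: Callable[[StubGenome], Any],
):
    """Build three parallel lists so that birds[i], ge[i] and nets[i] are one individual."""
    birds, ge, nets = [], [], []
    for genome_id, genome in genomes:
        birds.append(make_bird(genome_id))
        ge.append(genome)
        nets.append(make_net(genome))
        genome.fitness = 0  # every genome starts the run with zero fitness
    return birds, ge, nets


# Usage: new offspring get new, larger ids, so the modulo keeps the color index in range
genomes = [(gid, StubGenome(gid)) for gid in (1, 2, 7)]
birds, ge, nets = setup_individuals(
    genomes,
    make_bird=lambda gid: ("bird", COLORS[gid % len(COLORS)]),
    make_net=lambda g: ("net", g.key),
)
```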

Step 2: Defining a fitness function

Next, we implement our fitness function. In Flappy Bird, this is straightforward: a bird's fitness is its score at the moment it dies, whether it hits a pipe or flies off the top or bottom of the screen.

  1. Create a function called remove_bird that takes the index of the bird, the list of genomes, the list of neural nets, and the current score.
  2. Set the bird’s genome to have fitness equal to the score.
  3. Remove the bird and its corresponding genome and neural net.
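The three steps above can be sketched without pygame. This is an illustration under assumptions, not the tutorial's code. The bird list is passed in explicitly instead of living on Bird.birds, and StubGenome, cull, cull_forward, and make_population are hypothetical helpers. The sketch also shows why removal order matters. If you pop from a list while enumerate walks it forward, the loop skips the element that slides into the freed slot. Walking the indices from last to first avoids this.

```python
from dataclasses import dataclass


@dataclass
class StubGenome:
    """Hypothetical stand-in for a neat-python genome."""
    key: int
    fitness: float = 0.0


def remove_bird(i, birds, genomes, nets, score):
    """Credit the dead bird's genome with the score, then drop index i from every list."""
    genomes[i].fitness += score
    birds.pop(i)
    genomes.pop(i)
    nets.pop(i)


def cull(birds, genomes, nets, is_dead, score):
    """Remove every dead bird, visiting indices last to first so pops never shift unvisited items."""
    for i in range(len(birds) - 1, -1, -1):
        if is_dead(birds[i]):
            remove_bird(i, birds, genomes, nets, score)


def cull_forward(birds, genomes, nets, is_dead, score):
    """Buggy pattern for comparison: after a pop, enumerate moves past the bird now at slot i."""
    for i, bird in enumerate(birds):
        if is_dead(bird):
            remove_bird(i, birds, genomes, nets, score)


def make_population(n):
    """Hypothetical population of n birds, genomes and nets with matching indices."""
    genomes = [StubGenome(k) for k in range(n)]
    return [f"bird{k}" for k in range(n)], genomes, [f"net{k}" for k in range(n)]


dead = {"bird1", "bird2"}

birds, genomes, nets = make_population(5)
all_genomes = list(genomes)  # keep references so fitness can be checked after removal
cull(birds, genomes, nets, lambda b: b in dead, score=3)

fwd_birds, fwd_genomes, fwd_nets = make_population(5)
cull_forward(fwd_birds, fwd_genomes, fwd_nets, lambda b: b in dead, score=3)
```

The same caution applies anywhere remove_bird is called from inside a loop over Bird.birds.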

Step 3: Deciding when to jump

When we implemented the game before, we used user input to tell a bird to jump. Now, we want to implement a way for the AI to decide when a bird should jump or not.

So, each frame we pass a set of inputs into each bird’s neural network and look at its output. If the output is above a threshold (> 0.5), the bird jumps. Otherwise, it does nothing.
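To make the threshold concrete, here is a toy, single-neuron stand-in for nets[i].activate. The weights, the bias, and the plain logistic sigmoid are assumptions for illustration only. The real network's topology and activation functions come from NEAT and its config file. Like neat-python's FeedForwardNetwork.activate, it returns a list with one value per output node, hence output[0].

```python
import math


def sigmoid(z: float) -> float:
    """Plain logistic sigmoid, clamped so math.exp cannot overflow."""
    z = max(-60.0, min(60.0, z))
    return 1.0 / (1.0 + math.exp(-z))


def toy_activate(inputs, weights, bias):
    """Hypothetical one-neuron 'network': sigmoid(bias + weighted sum), returned as a list."""
    z = bias + sum(w * x for w, x in zip(weights, inputs))
    return [sigmoid(z)]


def should_jump(output, threshold=0.5):
    """Jump only when the first (and only) output is strictly above the threshold."""
    return output[0] > threshold


# Hypothetical weights that only look at the distance to the top pipe
weights, bias = (0.0, 0.01, 0.0), -1.0
```

With a sigmoid output, output[0] > 0.5 holds exactly when the neuron's weighted input (bias plus weighted sum) is positive. The 0.5 threshold therefore amounts to asking which side of zero that sum falls on.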

<aside> 💡 Before we go on, we need to decide what to use as inputs! It might help to think about what information you, as a player, would need to decide whether to jump in Flappy Bird.

</aside>

One set of inputs might be:

  1. the bird’s y position
  2. the position of the top pipe relative to the bird (its vertical distance from the bird)
  3. the position of the bottom pipe relative to the bird
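Computing those three inputs for one bird could look like this pygame-free sketch. PipeStub, closest_pipe, and net_inputs are hypothetical. The stub's fields stand in for the game's Pipe.right(), top_pipe_y(), and bottom_pipe_y() methods, and the pipe list is assumed to be ordered oldest (leftmost) first. The "closest" pipe is the first one whose right edge has not yet passed the bird's x position.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class PipeStub:
    """Hypothetical pipe: x of its right edge, and y of each pipe's edge at the gap."""
    right: float
    top_y: float
    bottom_y: float


def closest_pipe(pipes: Sequence[PipeStub], bird_x: float) -> Optional[PipeStub]:
    """Return the first pipe the bird has not fully passed, or None if there is none."""
    for pipe in pipes:  # assumed ordered left to right, oldest first
        if pipe.right >= bird_x:
            return pipe
    return None


def net_inputs(bird_y: float, pipe: PipeStub) -> Tuple[float, float, float]:
    """Bird height, absolute distance to the top pipe, signed offset to the bottom pipe."""
    return (bird_y, abs(pipe.top_y - bird_y), pipe.bottom_y - bird_y)


pipes = [PipeStub(right=200, top_y=150, bottom_y=350), PipeStub(right=520, top_y=100, bottom_y=300)]
target = closest_pipe(pipes, bird_x=250)
```

Returning None makes the "no pipe ahead" case explicit. The caller then has to decide what to feed the network when every pipe is behind the bird.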
Putting it all together, here is the complete eval_genomes.py. Lines marked # NEW are the changes to the code copied from main.py:

```python
import os
import random

import neat
import pygame

from classes.background import Background
from classes.bird import Bird
from classes.pipe import Pipe

pygame.init()

SCREEN_WIDTH, SCREEN_HEIGHT = 500, 768
SCREEN = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
FONT = pygame.font.Font('freesansbold.ttf', 72)
pygame.display.set_caption("NEAT - Flappy Bird")


def display_score(score):
    score_img = FONT.render("{}".format(score), True, (255, 255, 255))
    SCREEN.blit(score_img, (SCREEN_WIDTH // 2, 60))


# NEW: credit a dead bird's genome with the score, then drop it from all three parallel lists
def remove_bird(i, genomes, nets, score):
    genomes[i].fitness += score
    Bird.birds.pop(i)
    genomes.pop(i)
    nets.pop(i)


def eval_genomes(genomes, config):
    FPS = 60
    clock = pygame.time.Clock()

    # Initialize the background
    bg = Background(SCREEN_WIDTH, SCREEN_HEIGHT)

    # Setup genomes
    Bird.birds = []
    Pipe.pipes = []

    # NEW: one bird and one neural net (phenotype) per genome, all sharing index i
    ge = []
    nets = []
    for genome_id, genome in genomes:
        Bird.birds.append(Bird(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2, Bird.COLORS[genome_id % 6]))
        ge.append(genome)
        nets.append(neat.nn.FeedForwardNetwork.create(genome, config))
        genome.fitness = 0

    # Start score at 0
    score = 0
    while True:
        # Event handling
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Stop training entirely; just leaving this loop would let NEAT
                # start the next generation on a closed display
                pygame.quit()
                raise SystemExit

        # Spawn a new pipe once the newest one has moved far enough left
        if len(Pipe.pipes) == 0 or Pipe.pipes[-1].right() < SCREEN_WIDTH - 300:
            bottom_y = random.randint(300, SCREEN_HEIGHT - 200)
            top_y = random.randint(100, bottom_y - 200)
            Pipe(SCREEN_WIDTH, bottom_y, top_y)  # assumed: the constructor adds itself to Pipe.pipes

        # NEW: the closest pipe is the first one the birds have not passed yet;
        # if every pipe is already behind them, fall back to the newest pipe
        closest_pipe = Pipe.pipes[-1]
        for pipe in Pipe.pipes:
            if pipe.right() >= SCREEN_WIDTH // 2:
                closest_pipe = pipe
                break

        # NEW: let each bird's network decide whether to jump
        for i, bird in enumerate(Bird.birds):
            output = nets[i].activate((
                bird.rect.y,
                abs(closest_pipe.top_pipe_y() - bird.rect.y),
                closest_pipe.bottom_pipe_y() - bird.rect.y,
            ))
            if output[0] > 0.5:
                bird.jump()

        # Updating and drawing
        dt = 1 / FPS

        SCREEN.fill((255, 255, 255))  # Clear background
        bg.update(dt)
        bg.draw(SCREEN)

        for pipe in Pipe.pipes:
            pipe.update(dt)
            pipe.draw(SCREEN)

        # Walk the birds backwards so each pop in remove_bird cannot skip the next
        # bird, and remove a dead bird only once even if several checks fail
        for i in range(len(Bird.birds) - 1, -1, -1):
            bird = Bird.birds[i]
            bird.update(dt)
            # Collisions
            hit_pipe = any(pipe.collide(bird) for pipe in Pipe.pipes)
            off_screen = bird.rect.top < 0 or bird.rect.bottom > SCREEN_HEIGHT
            if hit_pipe or off_screen:
                remove_bird(i, ge, nets, score)  # NEW
                continue
            bird.draw(SCREEN)

        if len(Bird.birds) == 0:
            break

        for pipe in Pipe.pipes:
            if pipe.right() < SCREEN_WIDTH // 2 and not pipe.scored:
                pipe.scored = True
                score += 1

        display_score(score)
        pygame.display.update()
        clock.tick(FPS)
```