from collections import deque, namedtuple
from pathlib import Path
import torch
import torch.nn as nn
import random
import time
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from tqdm.notebook import trange
import pickle
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
import gym
import numpy as np
import collections
import cv2
import matplotlib.pyplot as plt
from matplotlib import animation
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = collections.deque(maxlen=2)
self._skip = skip
def step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def reset(self):
"""Clear past frame buffer and init to first obs"""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
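# A minimal sketch (not part of the original notebook) of the max-pooling trick above:
# NES sprites flicker on alternating frames, so taking an element-wise max over the
# last two raw frames keeps every sprite visible in the returned observation.
frame_a = np.array([[0, 255], [10, 0]], dtype=np.uint8)   # sprite visible top-right
frame_b = np.array([[0, 0], [10, 200]], dtype=np.uint8)   # sprite visible bottom-right
print(np.max(np.stack([frame_a, frame_b]), axis=0))       # [[  0 255], [ 10 200]] -- both survive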
class ProcessFrame84(gym.ObservationWrapper):
"""
Downsamples image to 84x84
Greyscales image
Returns numpy array
"""
def __init__(self, env=None):
super(ProcessFrame84, self).__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def observation(self, obs):
return ProcessFrame84.process(obs)
@staticmethod
def process(frame):
if frame.size == 240 * 256 * 3:
img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
else:
assert False, "Unknown resolution."
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
x_t = resized_screen[18:102, :]
x_t = np.reshape(x_t, [84, 84, 1])
return x_t.astype(np.uint8)
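# Quick sanity check (my addition, not in the original): push a random raw NES frame
# through the preprocessing and confirm we get a single 84x84 greyscale channel back.
_raw = np.random.randint(0, 256, size=(240, 256, 3), dtype=np.uint8)
_processed = ProcessFrame84.process(_raw)
print(_processed.shape, _processed.dtype)   # (84, 84, 1) uint8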
class ImageToPyTorch(gym.ObservationWrapper):
def __init__(self, env):
super(ImageToPyTorch, self).__init__(env)
old_shape = self.observation_space.shape
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]),
dtype=np.float32)
def observation(self, observation):
return np.moveaxis(observation, 2, 0)
class ScaledFloatFrame(gym.ObservationWrapper):
"""Normalize pixel values in frame --> 0 to 1"""
def observation(self, obs):
return np.array(obs).astype(np.float32) / 255.0
class BufferWrapper(gym.ObservationWrapper):
def __init__(self, env, n_steps, dtype=np.float32):
super(BufferWrapper, self).__init__(env)
self.dtype = dtype
old_space = env.observation_space
self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
old_space.high.repeat(n_steps, axis=0), dtype=dtype)
def reset(self):
self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
return self.observation(self.env.reset())
def observation(self, observation):
self.buffer[:-1] = self.buffer[1:]
self.buffer[-1] = observation
return self.buffer
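# Minimal sketch (my addition) of the FIFO behaviour BufferWrapper implements: the
# oldest frame is shifted out of slot 0 and the newest observation lands in the last slot.
_buf = np.zeros((4, 2, 2), dtype=np.float32)              # stand-in for 4 stacked 84x84 frames
for _t in range(1, 6):
    _buf[:-1] = _buf[1:]                                  # drop the oldest frame
    _buf[-1] = np.full((2, 2), _t, dtype=np.float32)      # append the newest one
print(_buf[:, 0, 0])                                      # [2. 3. 4. 5.] -> the last four frames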
def make_env(env_name):
env = gym_super_mario_bros.make(env_name)
env = MaxAndSkipEnv(env)
env = ProcessFrame84(env)
env = ImageToPyTorch(env)
env = BufferWrapper(env, 4)
env = ScaledFloatFrame(env)
return JoypadSpace(env, SIMPLE_MOVEMENT)
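# Throwaway smoke test (my addition) for the wrapper stack: after all wrappers the
# observation should be a float32 array of shape (4, 84, 84) scaled into [0, 1].
_env = make_env('SuperMarioBros-1-1-v0')
_obs = _env.reset()
print(_obs.shape, _obs.dtype, _obs.min(), _obs.max())     # (4, 84, 84) float32, values in [0, 1]
_env.close()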
class DQNSolver(nn.Module):
def __init__(self, input_shape, n_actions):
super(DQNSolver, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU()
)
conv_out_size = self._get_conv_out(input_shape)
self.fc = nn.Sequential(
nn.Linear(conv_out_size, 512),
nn.ReLU(),
nn.Linear(512, n_actions)
)
def _get_conv_out(self, shape):
o = self.conv(torch.zeros(1, *shape))
return int(np.prod(o.size()))
def forward(self, x):
conv_out = self.conv(x).view(x.size()[0], -1)
return self.fc(conv_out)
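# Worked size check (my addition): each conv layer shrinks an 84x84 input as
# floor((size - kernel) / stride) + 1, i.e. 84 -> 20 -> 9 -> 7, so the flattened
# feature vector has 64 * 7 * 7 = 3136 entries -- exactly what _get_conv_out finds.
print(DQNSolver((4, 84, 84), 7)._get_conv_out((4, 84, 84)))   # 3136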
device = "cuda" if torch.cuda.is_available() else "cpu"
env = make_env('SuperMarioBros-1-1-v0')
observation_space = env.observation_space.shape
action_space = env.action_space.n
obs = env.reset()
# warm the environment up with random NOOP / right presses (SIMPLE_MOVEMENT indices 0 and 1)
for _ in range(100):
    obs, _, done, _ = env.step(random.randint(0, 1))
if done: break
fig, axes = plt.subplots(ncols=4, figsize=(12, 4))
for i in range(4):
axes[i].imshow(obs[i], cmap=plt.get_cmap('gray'))
axes[i].axis('off')
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'done', 'reward'))
class ReplayMemory:
def __init__(self, capacity):
self.memory = deque(maxlen=capacity)
def push(self, *args):
self.memory.append(Transition(*args))
def sample(self, batch_size):
return random.sample(self.memory, batch_size)
def __len__(self):
return len(self.memory)
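# Usage sketch for ReplayMemory (my addition): push a few dummy transitions and draw a
# mini-batch; once capacity is exceeded the underlying deque silently drops the oldest entries.
_mem = ReplayMemory(capacity=3)
for _i in range(5):
    _mem.push(torch.zeros(1, 4, 84, 84), torch.tensor([_i % 2]),
              torch.zeros(1, 4, 84, 84), torch.tensor([[0.0]]), torch.tensor([[float(_i)]]))
print(len(_mem))                  # 3 -- only the three newest transitions survive
print(_mem.sample(2)[0].reward)   # reward of one of transitions 2, 3 or 4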
class DQNAgent:
def __init__(self, seed, input_dim, num_actions, learning_rate,
capacity, batch_size, gamma, sync_every, burnin, learn_every,
epsilon_start, epsilon_min, epsilon_decay,
save_every, logger):
torch.manual_seed(seed)
env.seed(seed + 1)
env.action_space.seed(seed + 2)
np.random.seed(seed + 3)
random.seed(seed + 4)
self.curr_step = 0
self.num_actions = num_actions
self.online_net = DQNSolver(input_dim, num_actions).to(device)
self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
self.target_net = DQNSolver(input_dim, num_actions).to(device)
self.target_net.load_state_dict(self.online_net.state_dict())
# Q_target parameters are frozen.
for p in self.target_net.parameters():
p.requires_grad = False
self.l1 = nn.SmoothL1Loss().to(device)
self.gamma = gamma
self.memory = ReplayMemory(capacity)
self.batch_size = batch_size
self.sync_every = sync_every
self.burnin = burnin
self.learn_every = learn_every
self.save_every = save_every
self.epsilon = epsilon_start
self.epsilon_start = epsilon_start
self.epsilon_min = epsilon_min
self.epsilon_decay = epsilon_decay
self.logger = logger
self.tmp_dir = Path('./tmp/')
def get_action(self, state, explore=True):
# epsilon-greedy part, we select a random action
if torch.rand(1).item() <= self.epsilon and explore:
action = torch.randint(0, self.num_actions, (1,)).item()
else:
Q_row = self.online_net(state.to(device))
            action = torch.argmax(Q_row).item()
if explore:
self.curr_step += 1
return action
def remember(self, state, action, next_state, done, reward):
        # store each transition on the training device so learn() can sample without extra host-to-device copies
        self.memory.push(state.to(device), action.to(device), next_state.to(device), done.to(device), reward.to(device))
def learn(self):
if self.curr_step < self.burnin:
return np.nan
if self.curr_step % self.learn_every != 0:
return np.nan
        # if we don't have enough experience yet, skip optimization and simply return
if len(self.memory) < self.batch_size:
return np.nan
        # sample from memory to create a batch of transitions
transitions = self.memory.sample(self.batch_size)
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
# detailed explanation). This converts batch-array of Transitions to Transition of batch-arrays.
batch = Transition(*zip(*transitions))
state = torch.cat(batch.state)
next_state = torch.cat(batch.next_state)
action = torch.cat(batch.action)
done = torch.cat(batch.done)
reward = torch.cat(batch.reward)
with torch.no_grad():
best_action = self.online_net(next_state).max(1)[1].view(-1, 1)
next_state_value = self.target_net(next_state).gather(1, best_action).view(-1, 1)
target = reward + self.gamma * torch.mul(next_state_value, 1 - done)
        # vanilla (non-double) DQN alternative:
        # target = reward + torch.mul(
        #     (self.gamma * self.target_net(next_state).max(1).values.unsqueeze(1)), 1 - done)
current = self.online_net(state).gather(1, action.unsqueeze(-1)).float()
self.optimizer.zero_grad()
loss = self.l1(current, target)
loss.backward()
        # optional: clamp gradients to [-1, 1] as in the original DQN paper
        # for param in self.online_net.parameters():
        #     param.grad.data.clamp_(-1, 1)
self.optimizer.step()
self.epsilon *= self.epsilon_decay
self.epsilon = max(self.epsilon, self.epsilon_min)
if self.curr_step % self.sync_every == 0:
self.sync()
if self.curr_step % self.save_every == 0:
self.save(self.curr_step)
        self.logger.info(f'optimizing at step {self.curr_step}, loss: {loss.item():.2f}')
return loss.item()
def sync(self):
logger.info(f"Synchronizing at step {self.curr_step}")
self.target_net.load_state_dict(self.online_net.state_dict())
    def save(self, label):
        if label is None:
            filename = "net.pt"
            torch.save(self.online_net.state_dict(), filename)
        else:
            self.tmp_dir.mkdir(parents=True, exist_ok=True)
            filename = self.tmp_dir / f"online-net-{label}.pt"
            torch.save(self.online_net.state_dict(), filename)
            filename = self.tmp_dir / f"target-net-{label}.pt"
            torch.save(self.target_net.state_dict(), filename)
def load(self):
        self.online_net.load_state_dict(torch.load("net.pt", map_location=device))
self.target_net.load_state_dict(self.online_net.state_dict())
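# Standalone sketch (my addition) of the double-DQN target computed in learn() above:
# the online net chooses the next action, the target net scores it, and terminal
# transitions (done = 1) receive no bootstrapped value.
_q_online_next = torch.tensor([[1.0, 3.0, 2.0]])       # online net Q-values for next_state
_q_target_next = torch.tensor([[0.5, 2.5, 4.0]])       # target net Q-values for next_state
_best = _q_online_next.max(1)[1].view(-1, 1)           # online argmax -> action 1
_bootstrap = _q_target_next.gather(1, _best)           # target net's estimate for that action: 2.5
_reward, _done, _gamma = torch.tensor([[1.0]]), torch.tensor([[0.0]]), 0.9
print(_reward + _gamma * _bootstrap * (1 - _done))     # tensor([[3.2500]]) = 1 + 0.9 * 2.5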
class MetricLogger:
def __init__(self, msg_filename='mario.log', data_filename='mario.csv'):
self.msg_filename = msg_filename
with open(self.msg_filename, 'w') as f:
f.write('Simulation starts\n')
self.data_filename = data_filename
with open(self.data_filename, 'w') as f:
f.write('Episode,Length,Reward,Loss,X-Pos,Time\n')
def append(self, episode, length, reward, loss, info):
with open(self.data_filename, 'a') as f:
f.write(f"{episode},{length},{reward},{loss},{info['x_pos']},{info['time']}\n")
def info(self, message):
with open(self.msg_filename, 'a') as f:
f.write(message + '\n')
input_dim = env.observation_space.shape
num_actions = env.action_space.n
print(f"Environment has a input dim {input_dim} and {num_actions} actions")
Environment has a input dim (4, 84, 84) and 7 actions
logger = MetricLogger()
agent = DQNAgent(seed=42, input_dim=input_dim, num_actions=num_actions,
learning_rate=0.00025, capacity=20_000, batch_size=32, gamma=0.9,
sync_every=5_000, burnin=1, learn_every=1, save_every=10_000,
epsilon_start=1, epsilon_min=0.02, epsilon_decay=0.9999,
logger=logger)
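# Back-of-the-envelope check (my addition): epsilon is multiplied by 0.9999 once per
# optimization step, so it hits the 0.02 floor after about log(0.02) / log(0.9999)
# optimization steps, i.e. roughly 39,000 steps of training.
print(np.log(0.02) / np.log(0.9999))   # ~39118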
test_every = 1_000
episodes = 10000
episode_losses, episode_rewards, episode_lens = [], [], []
for episode in (pbar := trange(1, episodes + 1)):
state, done, episode_len, episode_reward, episode_loss = env.reset(), False, 0, 0.0, 0.0
state = torch.tensor([state], dtype=torch.float)
while not done:
episode_len += 1
        try:
            env.render()
        except Exception:
            # rendering is unavailable in headless setups (e.g. remote notebooks)
            pass
action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
episode_reward += reward
action = torch.tensor([action], dtype=torch.long)
next_state = torch.tensor([next_state], dtype=torch.float)
reward = torch.tensor([reward], dtype=torch.float).unsqueeze(0)
done = torch.tensor([float(done)], dtype=torch.float).unsqueeze(0)
agent.remember(state, action, next_state, done, reward)
state = next_state
        step_loss = agent.learn()
        if not np.isnan(step_loss):
            episode_loss += step_loss
logger.append(episode, episode_len, episode_reward, episode_loss, info)
episode_losses.append(episode_loss / episode_len)
episode_rewards.append(episode_reward)
episode_lens.append(episode_len)
avg_lens = np.array(episode_lens[-50:]).mean()
avg_rewards = np.array(episode_rewards[-50:]).mean()
pbar.set_description(f"{max(episode_lens)}/{avg_lens:.2f}/{avg_rewards:.2f}/{len(agent.memory)}/{agent.epsilon:.2f} ~ {agent.curr_step}")
agent.save(None)
agent.load()
agent.epsilon = 0.05
env = make_env('SuperMarioBros-1-3-v0')
state = env.reset()
state = torch.tensor([state], dtype=torch.float)
frames = [env.render(mode='rgb_array').copy()]
done = False
total_reward = 0.0
while not done:
action = agent.get_action(state)
next_state, reward, done, info = env.step(action)
next_state = torch.tensor([next_state], dtype=torch.float)
total_reward += reward
#print(info, reward)
frames.append(env.render(mode='rgb_array').copy())
state = next_state
time.sleep(0.025)
env.render()
print(f"Total reward: {total_reward}")
env.close()
# np array with shape (frames, height, width, channels)
video = np.array(frames)
fig = plt.figure(figsize=(4, 4))
im = plt.imshow(video[0,:,:,:])
plt.axis('off')
plt.close()  # prevent the static figure from being displayed alongside the animation
def init():
im.set_data(video[0,:,:,:])
def animate(i):
im.set_data(video[i,:,:,:])
return im
anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
interval=100)
anim.save('super-mario-video.gif', dpi=80, writer='imagemagick')
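# If ImageMagick is not installed, matplotlib's built-in Pillow writer is a drop-in
# alternative (assumes the pillow package is available):
# anim.save('super-mario-video.gif', dpi=80, writer=animation.PillowWriter(fps=10))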
filenane = "net.pt"
torch.save(agent.online_net.state_dict(), filenane)