This assumes you have already completed Reinforcement Learning 16 (the previous article). After some trial and error, I arrived at the version below. The key point is that files are saved to and loaded from Google Drive directly. One remaining issue: I can't get rid of the static figure that appears below the final animation.
import google.colab.drive
google.colab.drive.mount('gdrive')
!ln -s gdrive/My\ Drive mydrive
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install pyvirtualdisplay > /dev/null 2>&1
!pip -q install JSAnimation
!pip -q install chainerrl
# ---- Experiment configuration ----------------------------------------------
gamename = 'CartPole-v0'
# Discount factor applied to future rewards.
gamma = 0.95
# Exploration rate for the constant epsilon-greedy explorer.
myepsilon = 0.3
mySteps = 20000  # Train the agent for 20000 steps
my_eval_n_episodes = 1  # 1 episode is sampled for each evaluation
my_eval_max_episode_len = 200  # Maximum length of each evaluation episode
my_eval_interval = 1000  # Evaluate the agent after every 1000 steps

# All outputs live under one Google Drive folder, reached via the 'mydrive'
# symlink. Deriving every path from one base keeps them consistent.
myBaseDir = 'mydrive/OpenAI/CartPole'
myOutDir = myBaseDir + '/result'  # Training output directory ('result')
myAgentDir = myBaseDir + '/agent'  # Saved agent directory ('agent')
myAnimName = myBaseDir + '/movie_cartpole.mp4'
# The score log is read from the training output directory.
myScoreName = myOutDir + '/scores.txt'
Program
Imports
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
Initialize the environment
# Build the environment, show its spaces, then take one random step and show
# everything env.step returns.
env = gym.make(gamename)
for label, space in (('observation space:', env.observation_space),
                     ('action space:', env.action_space)):
    print(label, space)

obs = env.reset()
print('initial observation:', obs)

action = env.action_space.sample()
obs, r, done, info = env.step(action)
for label, value in (('next observation:', obs),
                     ('reward:', r),
                     ('done:', done),
                     ('info:', info)):
    print(label, value)
Deep Q-Network settings
# Q-function: a fully connected network taking the observation (size
# observation_space.shape[0]) and producing values for each of the
# n_actions discrete actions; 2 hidden layers of 50 units.
n_actions = env.action_space.n
obs_size = env.observation_space.shape[0]
q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
    obs_size,
    n_actions,
    n_hidden_channels=50,
    n_hidden_layers=2,
)
Use Adam to optimize q_func. eps=1e-2 is for stability.
# Adam over q_func's parameters; eps=1e-2 is chosen for stability.
adam = chainer.optimizers.Adam(eps=1e-2)
adam.setup(q_func)
optimizer = adam
Agent settings. DQN uses experience replay.
Specify a replay buffer and its capacity.
Since observations from CartPole-v0 are numpy.float64 while
Chainer only accepts numpy.float32 by default, specify a converter as a feature extractor function phi.
# Epsilon-greedy exploration with a fixed epsilon (myepsilon); random actions
# are drawn from the environment's action space.
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=myepsilon, random_action_func=env.action_space.sample)
# Experience-replay buffer with capacity for 10**6 entries.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)


def phi(x):
    """Convert an observation to numpy.float32, Chainer's default dtype.

    CartPole observations arrive as float64; ``copy=False`` skips the copy
    when ``x`` is already float32.
    """
    return x.astype(np.float32, copy=False)


# Double DQN agent; target network synced every 100 updates.
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)
Train
Set up the logger to print info messages for understandability.
import logging
import sys

# Send INFO-level log messages to stdout without any extra formatting.
logging.basicConfig(format='', stream=sys.stdout, level=logging.INFO)

# Train for mySteps steps, evaluating every my_eval_interval steps; output
# (including the score log) is written to myOutDir.
chainerrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=mySteps,
    eval_n_steps=None,
    eval_n_episodes=my_eval_n_episodes,
    eval_max_episode_len=my_eval_max_episode_len,
    eval_interval=my_eval_interval,
    outdir=myOutDir,
)

# Persist the trained agent to Google Drive.
agent.save(myAgentDir)
Data Table
import pandas as pd
import glob
import os

# Locate the score log. myScoreName may be a literal path or a glob pattern;
# when several files match, the most recently modified one is used.
score_files = sorted(glob.glob(myScoreName), key=os.path.getmtime)
if not score_files:
    # Fail with a clear message instead of a bare IndexError on score_files[-1].
    raise FileNotFoundError(
        'No score file matches {!r}; run the training cell first.'.format(myScoreName))
score_file = score_files[-1]
# The score log is tab-separated.
df = pd.read_csv(score_file, delimiter='\t')
df
Figure: average Q
# Line plot of the 'average_q' column against 'steps'.
df.plot.line(x='steps', y='average_q')
Test
Imports for testing
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1024, 768))
display.start()
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
import matplotlib.pyplot as plt
%matplotlib inline
Test Program
# Run test episodes with the trained agent (via agent.act), capturing an RGB
# frame before every step so the run can be animated afterwards.
n_test_episodes = 3
max_test_steps = 200  # same value as my_eval_max_episode_len
frames = []
env = gym.make(gamename)
# Monitor wrapper writing into myOutDir; force=True presumably replaces any
# earlier monitor output there — confirm against the gym version in use.
envw = gym.wrappers.Monitor(env, myOutDir, force=True)
try:
    for i in range(n_test_episodes):
        obs = envw.reset()
        done = False
        R = 0  # total reward for this episode
        t = 0  # steps taken in this episode
        while not done and t < max_test_steps:
            frames.append(envw.render(mode='rgb_array'))
            action = agent.act(obs)
            obs, r, done, _ = envw.step(action)
            R += r
            t += 1
        print('test episode:', i, 'R:', R)
        # Tell the agent the episode has ended before starting the next one.
        agent.stop_episode()
    envw.render()
finally:
    # Always close the monitor, even if an episode raises.
    envw.close()
from IPython.display import HTML

if not frames:
    raise ValueError('No frames were recorded; run the test episodes first.')

# Size the figure in inches so that, at 72 dpi, it matches the frame's pixels.
frame_height, frame_width = frames[0].shape[0], frames[0].shape[1]
fig = plt.figure(figsize=(frame_width / 72.0, frame_height / 72.0), dpi=72)
patch = plt.imshow(frames[0])
plt.axis('off')


def animate(i):
    """Show frame ``i`` of the recorded frames; return the updated artist."""
    patch.set_data(frames[i])
    return (patch,)


anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
anim.save(myAnimName)
# Close the figure so the inline backend does not also render it as a static
# image below the animation; the animation keeps its own reference to `fig`.
plt.close(fig)
HTML(anim.to_jshtml())
Recommended Posts