Reinforcement learning 22 Colaboratory + CartPole + ChainerRL + A3C

This article assumes that you have completed Reinforcement learning 21. A3C is an abbreviation for Asynchronous Advantage Actor-Critic. For a detailed explanation, see "[Reinforcement learning] A3C to learn while implementing [Stick with CartPole: Complete with 1 file]".

As with 21, I turned the ChainerRL example script into a notebook as is. Training took a long time and ran into the 90-minute rule, so I ran a scaled-down version.

Google drive mount

!ln -s gdrive/My\ Drive mydrive

program install

!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install pyvirtualdisplay > /dev/null 2>&1
!pip -q install JSAnimation
!pip -q install chainerrl

Main program: an example of training A3C against OpenAI Gym Envs.

This script is an example of training an A3C agent against OpenAI Gym envs. Both discrete and continuous action spaces are supported.

modules import

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import *  # NOQA
from future import standard_library
standard_library.install_aliases()  # NOQA
import argparse
import os
import sys

import chainer
from chainer import functions as F
from chainer import links as L
import gym
import numpy as np

import chainerrl
from chainerrl.agents import a3c
from chainerrl import experiments
from chainerrl import links
from chainerrl import misc
from chainerrl.optimizers.nonbias_weight_decay import NonbiasWeightDecay
from chainerrl.optimizers import rmsprop_async
from chainerrl import policies
from chainerrl.recurrent import RecurrentChainMixin
from chainerrl import v_function

Class A3CFFSoftmax: an example of an A3C feedforward softmax policy.

class A3CFFSoftmax(chainer.ChainList, a3c.A3CModel):
    """Feedforward A3C model with a softmax policy head.

    The policy and the state-value function are two separate MLPs that
    share nothing but the input observation.

    Args:
        ndim_obs: Input size of both MLPs (observation size).
        n_actions: Output size of the policy MLP (number of actions).
        hidden_sizes: Hidden layer sizes used for both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        policy_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.SoftmaxPolicy(model=policy_net)
        # The value network maps an observation to a single scalar output.
        value_net = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        self.v = value_net
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out

Class A3CFFMellowmax: an example of an A3C feedforward mellowmax policy.

class A3CFFMellowmax(chainer.ChainList, a3c.A3CModel):
    """Feedforward A3C model with a mellowmax policy head.

    The policy and the state-value function are two separate MLPs that
    share nothing but the input observation.

    Args:
        ndim_obs: Input size of both MLPs (observation size).
        n_actions: Output size of the policy MLP (number of actions).
        hidden_sizes: Hidden layer sizes used for both MLPs.
    """

    def __init__(self, ndim_obs, n_actions, hidden_sizes=(200, 200)):
        policy_net = links.MLP(ndim_obs, n_actions, hidden_sizes)
        self.pi = policies.MellowmaxPolicy(model=policy_net)
        # The value network maps an observation to a single scalar output.
        value_net = links.MLP(ndim_obs, 1, hidden_sizes=hidden_sizes)
        self.v = value_net
        super().__init__(self.pi, self.v)

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        policy_out = self.pi(state)
        value_out = self.v(state)
        return policy_out, value_out

Class A3CLSTMGaussian: an example of an A3C recurrent Gaussian policy.

class A3CLSTMGaussian(chainer.ChainList, a3c.A3CModel, RecurrentChainMixin):
    """Recurrent A3C model with a Gaussian policy head.

    The policy and the value function each have their own pipeline of
    linear layer -> ReLU -> LSTM -> output head; no layer is shared
    between the two.

    Args:
        obs_size: Input size of the two linear heads (observation size).
        action_size: Action size passed to the Gaussian policy.
        hidden_size: Output size of the linear heads / LSTM input size.
        lstm_size: Output size of the LSTMs / input size of the output heads.
    """

    def __init__(self, obs_size, action_size, hidden_size=200, lstm_size=128):
        self.pi_head = L.Linear(obs_size, hidden_size)
        self.v_head = L.Linear(obs_size, hidden_size)
        self.pi_lstm = L.LSTM(hidden_size, lstm_size)
        self.v_lstm = L.LSTM(hidden_size, lstm_size)
        self.pi = policies.FCGaussianPolicy(lstm_size, action_size)
        self.v = v_function.FCVFunction(lstm_size)
        children = (self.pi_head, self.v_head,
                    self.pi_lstm, self.v_lstm, self.pi, self.v)
        super().__init__(*children)

    def pi_and_v(self, state):
        """Return the policy output and the value output for ``state``."""
        # Policy pipeline: head -> ReLU -> LSTM -> Gaussian policy.
        pi_hidden = F.relu(self.pi_head(state))
        pi_hidden = self.pi_lstm(pi_hidden)
        pout = self.pi(pi_hidden)

        # Value pipeline: head -> ReLU -> LSTM -> value function.
        v_hidden = F.relu(self.v_head(state))
        v_hidden = self.v_lstm(v_hidden)
        vout = self.v(v_hidden)

        return pout, vout



import logging

def _build_arg_parser():
    """Create the argument parser holding every option of this script."""
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    # Parallelism, environment and model selection.
    add('--processes', type=int, default=8)
    add('--env', type=str, default='CartPole-v0')
    add('--arch', type=str, default='FFSoftmax',
        choices=('FFSoftmax', 'FFMellowmax', 'LSTMGaussian'))
    add('--seed', type=int, default=0)
    add('--outdir', type=str,
        default='mydrive/OpenAI/CartPole/result-a3c')

    # Training hyperparameters and schedule.
    add('--t-max', type=int, default=5)
    add('--beta', type=float, default=1e-2)
    add('--profile', action='store_true')
    add('--steps', type=int, default=8 * 10 ** 7)
    add('--eval-interval', type=int, default=10 ** 5)
    add('--eval-n-runs', type=int, default=10)
    add('--reward-scale-factor', type=float, default=1e-2)
    add('--rmsprop-epsilon', type=float, default=1e-1)
    add('--render', action='store_true', default=False)
    add('--lr', type=float, default=7e-4)
    add('--weight-decay', type=float, default=0.0)

    # Evaluation, checkpoint loading, logging and recording.
    add('--demo', action='store_true', default=False)
    add('--load', type=str, default='')
    add('--logger-level', type=int, default=logging.INFO)
    add('--monitor', action='store_true')
    return arg_parser


parser = _build_arg_parser()

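To change a setting, pass it to parse_args as a list of strings, for example:

args = parser.parse_args(['--env', 'CartPole-v0'])

Here, the number of steps and the evaluation interval are reduced from their defaults: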

# Parse a fixed argument list instead of the real command line. The step
# count and evaluation interval are reduced from the parser defaults.
args = parser.parse_args(['--steps', '300000', '--eval-interval', '10000'])
logging.basicConfig(level=args.logger_level, stream=sys.stdout, format='')

# Set a random seed used in ChainerRL.
# If you use more than one process, the results will no longer be
# deterministic even with the same random seed.
misc.set_random_seed(args.seed)

# Set different random seeds for different subprocesses.
# If seed=0 and processes=4, subprocess seeds are [0, 1, 2, 3].
# If seed=1 and processes=4, subprocess seeds are [4, 5, 6, 7].
process_seeds = np.arange(args.processes) + args.seed * args.processes
assert process_seeds.max() < 2 ** 32

# Make sure the output directory exists before anything writes into it.
if not os.path.exists(args.outdir):
    os.makedirs(args.outdir)


def make_env(process_idx, test):
    """Create the wrapped Gym environment for one worker process.

    Args:
        process_idx: Index of the worker; selects its entry in
            ``process_seeds``.
        test: True for an evaluation environment. Evaluation environments
            get a different seed and are neither reward-scaled nor rendered.

    Returns:
        The environment named by ``args.env`` with the wrappers applied.
    """
    env = gym.make(args.env)
    # Use different random seeds for train and test envs
    process_seed = int(process_seeds[process_idx])
    env_seed = 2 ** 32 - 1 - process_seed if test else process_seed
    # Apply the seed; previously it was computed but never used.
    env.seed(env_seed)
    # Cast observations to float32 because our model uses float32
    env = chainerrl.wrappers.CastObservationToFloat32(env)
    if args.monitor and process_idx == 0:
        env = chainerrl.wrappers.Monitor(env, args.outdir)
    if not test:
        # Scale rewards (and thus returns) to a reasonable range so that
        # training is easier
        env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)
    if args.render and process_idx == 0 and not test:
        env = chainerrl.wrappers.Render(env)
    return env

Select a model by type of action.

# A throwaway environment, used only to read the episode step limit and the
# observation/action spaces needed to size the model.
sample_env = gym.make(args.env)
# The call was left unterminated; the key below is the one used by the
# ChainerRL example for Gym versions that expose ``spec.tags``.
# NOTE(review): newer Gym exposes this as ``spec.max_episode_steps`` - confirm
# against the installed version.
timestep_limit = sample_env.spec.tags.get(
    'wrapper_config.TimeLimit.max_episode_steps')
obs_space = sample_env.observation_space
action_space = sample_env.action_space

# Build the model for the requested architecture. Each builder is lazy
# because the Gaussian model reads ``action_space.low`` while the two
# feedforward models read ``action_space.n``.
_model_builders = {
    'LSTMGaussian': lambda: A3CLSTMGaussian(
        obs_space.low.size, action_space.low.size),
    'FFSoftmax': lambda: A3CFFSoftmax(
        obs_space.low.size, action_space.n),
    'FFMellowmax': lambda: A3CFFMellowmax(
        obs_space.low.size, action_space.n),
}
if args.arch in _model_builders:
    model = _model_builders[args.arch]()


# Asynchronous RMSprop optimizer, using the learning rate given by --lr.
opt = rmsprop_async.RMSpropAsync(
    lr=args.lr, eps=args.rmsprop_epsilon, alpha=0.99)
# The optimizer must be attached to the model before the agent can update it.
opt.setup(model)
# NOTE(review): the upstream ChainerRL A3C example also adds a gradient
# clipping hook at this point - confirm whether it was intended here.
if args.weight_decay > 0:
    opt.add_hook(NonbiasWeightDecay(args.weight_decay))

agent = a3c.A3C(model, opt, t_max=args.t_max, gamma=0.99,
                beta=args.beta)
if args.load:
    # Resume from a previously saved agent directory.
    agent.load(args.load)

if args.demo:
    # Evaluation only: run the (loaded) agent without any training.
    env = make_env(0, True)
    eval_stats = experiments.eval_performance(
        env=env,
        agent=agent,
        n_steps=None,
        n_episodes=args.eval_n_runs,
        max_episode_len=timestep_limit)
    print('n_runs: {} mean: {} median: {} stdev {}'.format(
        args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
        eval_stats['stdev']))
else:
    # Asynchronous training over ``args.processes`` workers with periodic
    # evaluation; results are written under ``args.outdir``.
    experiments.train_agent_async(
        agent=agent,
        outdir=args.outdir,
        processes=args.processes,
        make_env=make_env,
        profile=args.profile,
        steps=args.steps,
        eval_n_steps=None,
        eval_n_episodes=args.eval_n_runs,
        eval_interval=args.eval_interval,
        max_episode_len=timestep_limit)



import pandas as pd
import glob
import os
# Locate the tab-separated score log inside the output directory.
# Sorting makes the "last" match deterministic should several files match.
score_files = sorted(glob.glob(args.outdir+'/scores.txt'))
if not score_files:
    # Fail with an explicit message instead of an opaque IndexError.
    raise FileNotFoundError(
        'no scores.txt found in {}'.format(args.outdir))
score_file = score_files[-1]
df = pd.read_csv(score_file, delimiter='\t' )


from pyvirtualdisplay import Display
# Off-screen display (backed by the xvfb package installed above) so the
# environment can render without a physical screen.
display = Display(visible=0, size=(1024, 768))
# The display has no effect until it is started.
display.start()

from JSAnimation.IPython_display import display_animation
from matplotlib import animation
import matplotlib.pyplot as plt
%matplotlib inline

# Rendered frames of the test episodes, collected for the animation.
frames = []
env = gym.make(args.env)

process_seeds = np.arange(args.processes) + args.seed  * args.processes
assert process_seeds.max() < 2 ** 32
env_seed = int(process_seeds[0])
# Apply the seed; previously it was computed but never used.
env.seed(env_seed)
# Cast observations to float32 because the model uses float32.
env = chainerrl.wrappers.CastObservationToFloat32(env)
# Rewards are scaled here as well, so the printed returns are scaled values.
env = chainerrl.wrappers.ScaleReward(env, args.reward_scale_factor)

# Record the episodes into the output directory; force=True clears earlier
# monitor files found there.
envw = gym.wrappers.Monitor(env, args.outdir, force=True)

# Run three test episodes of at most 200 steps each, keeping every rendered
# frame for the animation.
for i in range(3):
    obs = envw.reset()
    done = False
    R = 0
    t = 0
    while not done and t < 200:
        frames.append(envw.render(mode = 'rgb_array'))
        action = agent.act(obs)
        obs, r, done, _ = envw.step(action)
        R += r
        t += 1
    print('test episode:', i, 'R:', R)
    # Tell the agent the episode is over so per-episode state (e.g. the
    # recurrent state of the LSTM model) does not leak into the next one.
    agent.stop_episode()

from IPython.display import HTML
# Size the figure from the frame's pixel dimensions at 72 dpi.
plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0),dpi=72)
patch = plt.imshow(frames[0])


def animate(i):
    """Show the i-th recorded frame in the image artist."""
    patch.set_data(frames[i])


anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),interval=50)
# Write the animation as a video file into the output directory.
anim.save(args.outdir+'/test.mp4')
# NOTE(review): the imported HTML / display_animation helpers suggest the
# animation was also shown inline - confirm which call was intended.

