Apprentissage par renforcement: accélérer l'itération de la valeur

introduction

Ces dernières années, en raison du succès d'AlphaGo et de DQN, le domaine de l'apprentissage par renforcement profond sans modèle a été activement étudié. Ces algorithmes sont l'une des approches efficaces lorsque l'espace état-action est grand dans des circonstances ou lorsque la modélisation mathématique de la dynamique est difficile. Cependant, parmi les problèmes rencontrés dans la réalité, il est relativement facile de modéliser mathématiquement l'environnement, et il existe de nombreux cas où l'espace état-action peut être réduit en imaginant. Pour de tels problèmes, je pense que l'utilisation de l'apprentissage par renforcement de table basé sur un modèle présente un grand avantage en termes de coûts de développement et d'exploitation.

Cependant, la taille de l'espace d'action d'état qui peut être gérée dans l'apprentissage du renforcement des tables dépend grandement de la vitesse du programme, et l'accélération est très importante. Par conséquent, dans cet article, nous allons présenter le savoir-faire pour exécuter la méthode d'itération __value __, qui est l'algorithme de base de l'apprentissage par renforcement, à grande vitesse. Au final, nous avons pu atteindre __500 fois plus vite __ que l'implémentation naïve.

Contexte

Processus décisionnel de Markov

Les processus décisionnels de Markov (MDP) sont un cadre utilisé dans la définition des problèmes de l'apprentissage par renforcement. L '«environnement» prend un état $ s $ à chaque instant, et l' «agent» décisionnel sélectionne arbitrairement l'action $ a $ disponible dans cet état. Après cela, l'environnement passe aléatoirement à un nouvel état, moment auquel l'agent reçoit une récompense de $ r $ correspondant à la transition d'état. Le paramètre de base du problème dans MDP est de trouver une mesure qui est une correspondance (distribution de probabilité) avec l'action optimale prise par un agent dans un certain état. Ensuite, si la fonction objectif est la récompense cumulative de remise, le problème de la recherche de la fonction de valeur optimale est le suivant.

\pi^* = \text{arg}\max_{\pi}  \text{E}_{s \sim \mu, \tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s] = \text{arg}\max_{\pi}  \text{E}_{s \sim \mu} [V_{\pi}(s)]

La fonction de valeur d'état $ V (s) $ et la fonction de valeur d'action $ Q (s, a) $ sont définies comme suit.

V_{\pi}(s) = \text{E}_{\tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s] \\
Q_{\pi}(s, a) = \text{E}_{\tau \sim \pi}[\sum _t \gamma^t r(s_t, a_t)|s_0=s, a_0=a]

Méthode de répétition des valeurs

Politique optimale $ V ^ * (s) = V_ {\ pi ^ *} (s) $ pour la fonction de valeur d'état dans $ \ pi ^ * $, $ Q ^ * (s, a) = Q_ {\ pour la fonction de valeur d'action Défini comme pi ^ *} (s, a) $. À partir de la définition de la fonction de valeur, nous pouvons voir que la fonction de valeur optimale satisfait l'équation de Bellman suivante.

Q^*(s, a) = \sum_{s', r} p(s', r|s, a) [r + \gamma V^*(s')] \\
V^*(s) = \max _a Q^*(s, a)

L'itération de valeur est un algorithme qui commence par une valeur initiale appropriée et adapte à plusieurs reprises l'équation de Belman pour mettre à jour alternativement $ V $ et $ Q $. Le pire montant de calcul est un polypole pour le nombre d'états et le nombre d'actions. Si la fonction de valeur d'action $ Q ^ * (s, a) $ peut être obtenue, la mesure optimale peut être obtenue en sélectionnant l'action $ a $ qui peut être effectuée dans l'état $ s $ et ayant la valeur d'action la plus élevée. Je peux le faire.

\pi(a|s) = \text{arg}\max _a Q^*(s,a)

algorithme

Préparation de l'expérience

La méthode d'itération de valeur est une idée très simple, mais il existe de nombreuses implémentations possibles. Ici, je voudrais présenter quelques implémentations et leurs vitesses de traitement en détail. Créez un MDP expérimental pour comparer chaque implémentation. Pour faciliter la mise en œuvre, considérons un MDP déterministe, un monde dans lequel la récompense $ r $ et l'état suivant $ s '$ sont déterminés de manière déterministe pour une action $ a $. À ce stade, l'équation de Belman est simplifiée comme suit.

Q^*(s, a) = r + \gamma V^*(s')

Le MDP déterministe peut être représenté par un graphique dans lequel l'état est un nœud, l'action est une arête et la récompense est un poids d'arête (attribut). La fonction suivante crée le MDP à utiliser dans l'expérience.

import networkx as nx
import random

def create_mdp(num_states, num_actions, reward_ratio=0.01, neighbor_count=30):
    get_reward = lambda: 1.0 if random.random() < reward_ratio else 0.0
    get_neighbor = lambda u: random.randint(u - neighbor_count, u + neighbor_count) % (num_states - 1)
    edges = [
        (i, (i + 1) % (num_states - 1), get_reward())
        for i in range(num_states)
    ]
    for _ in range(num_states * (num_actions - 1)):
        u = random.randint(0, num_states - 1)
        v = get_neighbor(u)
        r = get_reward()
        edges.append((u, v, r))
    G = nx.DiGraph()
    G.add_weighted_edges_from(edges)
    return G

Le nombre d'états et le nombre d'actions (le nombre moyen d'actions dans chaque état) sont spécifiés, et un graphique aléatoire fortement connecté et une récompense éparse sont générés. Il est représenté par DiGraph, où le nœud de networkx est l'état, le bord est l'action et l'attribut poids du bord est la récompense.

Dans de futures expériences, nous utiliserons MDP avec 10000 états et une moyenne de 3 actions pour chaque état.

num_states = 10000
num_actions = 3
G = create_mdp(num_states, num_actions)

Implémentation de la méthode d'itération de valeur naïve

L'algorithme le plus simple est une méthode de répétition de "mise à jour de la valeur d'état de tous les états" et de "mise à jour de la valeur d'action de toutes les actions", et est parfois appelée programmation dynamique synchrone.

class NonConvergenceError(Exception):
    pass

class SyncDP:
    
    def __init__(self, G, gamma, max_sweeps, threshold):
        self.G = G
        self.gamma = gamma
        self.max_sweeps = max_sweeps
        self.threshold = threshold
        self.V = {state : 0 for state in G.nodes}
        self.TD = {state : 0 for state in G.nodes}
        self.Q = {(state, action) : 0 for state, action in G.edges}

    def get_reward(self, s, a):
        return self.G.edges[s, a]['weight']

    def sweep(self):
        for state in self.G.nodes:
            for action in self.G.successors(state):
                self.Q[state, action] = self.get_reward(state, action) + self.gamma * self.V[action]
        for state in self.G.nodes:
            v_new = max([self.Q[state, action] for action in self.G.successors(state)])
            self.TD[state] = abs(self.V[state] - v_new)
            self.V[state] = v_new

    def run(self):
        for _ in range(self.max_sweeps):
            self.sweep()
            if (np.array(list(self.TD.values())) < self.threshold).all():
                return self.V
        raise NonConvergenceError

Les paramètres du taux d'actualisation gamma, du seuil de convergence et du nombre maximal de balayages sont initialement déterminés par les exigences de l'application, mais ici, définissez les valeurs appropriées.

gamma = 0.95
threshold = 0.01
max_sweeps = 1000

Lorsqu'il est exécuté dans cette condition, le temps de traitement est le suivant.

%timeit V = SyncDP(G, gamma, max_sweeps, threshold).run()
8.83 s ± 273 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Si le nombre d'états est de 10000 et que cela prend autant de temps, les applications qui peuvent être utilisées seront probablement assez limitées. Ce temps de traitement sera la base des améliorations futures.

Asynchronous Dynamic Programming (Async DP) Étant donné que l'algorithme de mise à jour de la valeur synchrone met à jour tous les états en un seul balayage, un seul balayage peut prendre un temps considérable si le nombre d'états est très grand. Asynchronous DP met à jour de manière itérative la valeur d'état sur l'impression. En d'autres termes, au lieu de préparer un nouveau tableau pour stocker les valeurs mises à jour à chaque fois comme DP synchrone et de mettre à jour toutes les valeurs d'état et de les stocker dans le nouveau tableau, calculez d'ici là à chaque mise à jour. Répétez les mises à jour des valeurs, en tirant parti des différentes valeurs d'état disponibles. Il est nécessaire de continuer à mettre à jour tous les états pour assurer la convergence, mais l'ordre de mise à jour peut être librement sélectionné. Par exemple, vous pouvez accélérer les mises à jour de valeur en ignorant les conditions qui sont moins pertinentes pour la meilleure stratégie.

Prioritized Sweeping Pour DP asynchrone, l'ordre des mises à jour des valeurs peut être arbitraire. Lors de la mise à jour de la valeur, tous les états ne sont pas également utiles pour mettre à jour la valeur d'autres états, et on s'attend à ce que certains états aient un impact significatif sur la valeur d'autres états. Par exemple, dans MDP où vous pouvez obtenir la récompense éparse à laquelle vous pensez, il est important de propager efficacement l'état récompensé à d'autres états. Par conséquent, l'algorithme suivant utilisant une file d'attente prioritaire peut être envisagé.

  1. Gérez la quantité de changement due à la mise à jour de la valeur dans tous les états avec une file d'attente prioritaire
  2. Mettez à jour la valeur de l'état en haut de la file d'attente
  3. Si le montant de la modification depuis la dernière mise à jour de la valeur dépasse le seuil, poussez la paire état / montant de modification dans la file d'attente.

Cet algorithme est appelé balayage prioritaire. La mise en œuvre ressemble à ceci:

import heapq

class PrioritizedDP(SyncDP):
    def run(self):
        self.sweep()
        pq = [
            (-abs(td_error), state)
            for state, td_error in self.TD.items()
            if abs(td_error) > self.threshold
            ]
        heapq.heapify(pq)
        while pq:
            _, action = heapq.heappop(pq)
            if self.TD[action] < self.threshold:
                continue
            self.TD[action] = 0
            for state in self.G.predecessors(action):
                self.Q[state, action] = self.get_reward(state, action) + self.gamma * self.V[action]
                v_new = max([self.Q[state, action] for action in self.G.successors(state)])
                td_error = abs(v_new - self.V[state])
                self.TD[state] += td_error
                if td_error > self.threshold:
                    heapq.heappush(pq, (-td_error, state))
                self.V[state] = v_new
        return self.V

Tout d'abord, la fonction de balayage est utilisée pour mettre à jour la valeur de tous les états, et le tas est construit en fonction de l'état dans lequel l'erreur TD dépasse le seuil. Après cela, répétez la mise à jour jusqu'à ce que la file d'attente soit épuisée. Tout d'abord, mettez à jour la valeur de l'action $ Q $ afin que l'état sorti de la file d'attente devienne l'état suivant (= action). Puis mettez à jour $ V $ pour la valeur d'état qui dépend de la valeur d'action mise à jour. Si la différence entre avant et après la mise à jour (td_error) dépasse le seuil, elle sera poussée dans la file d'attente. Le temps de traitement est le suivant, et nous avons pu atteindre environ deux fois la vitesse.

%timeit V = PrioritizedDP(G, gamma, max_sweeps, threshold).run()
4.06 s ± 115 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Calcul vectoriel

Jusqu'à présent, nous avons implémenté des algorithmes basés sur les objets networkx, mais les méthodes fréquemment appelées telles que les successeurs prennent la plupart du temps. Voyons comment utiliser le calcul vectoriel par numpy tout en rendant la structure de données plus efficace. En exprimant le graphique sous la forme d'un tableau numpy, il devient possible d'utiliser le calcul vectoriel pour le calcul de la valeur d'action comme suit.

class ArraySyncDP:
    
    def __init__(self, A : ArrayGraph, gamma, max_sweeps, threshold):
        self.A = A
        self.gamma = gamma
        self.max_sweeps = max_sweeps
        self.threshold = threshold
        self.V = np.zeros(A.num_states)
        self.TD = np.full(A.num_states, np.inf)
        self.Q = np.zeros(A.num_actions)

    def run(self):
        for _ in range(self.max_sweeps):
            self.Q[:] = self.A.reward + self.gamma * self.V[self.A.action2next_state]
            for state_id in range(self.A.num_states):
                start, end = self.A.state2action_start[state_id], self.A.state2action_start[state_id + 1]
                v_new = self.Q[start : end].max()
                self.TD[state_id] = abs(self.V[state_id] - v_new)
                self.V[state_id] = v_new

            if (self.TD < self.threshold).all():
                return self.V
        raise NonConvergenceError
%timeit V = ArraySyncDP(A, gamma, max_sweeps, threshold).run()
3.5 s ± 99.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

C'est légèrement plus rapide que le balayage prioritaire, mais ce n'est pas très efficace. Il semble que la raison principale est que nous n'avons pas été en mesure de tout convertir en vecteurs et que nous utilisons encore pendant quelques minutes pour mettre à jour la valeur de l'état.

Cython Par nature, les tableaux contiennent leurs types de données de base dans des zones contiguës de la mémoire, ils devraient donc être accessibles plus rapidement que les listes qui référencent des objets dispersés dans la mémoire. Cependant, Python semble être plus lent à accéder aux éléments que les listes et les dictionnaires car il se convertit en objets Python pour référencer des éléments individuels du tableau, ce qui entraîne une surcharge.

Pensons donc à l'utilisation de Cython pour accélérer l'accès aux tableaux. Cython est un compilateur qui convertit Python annoté de type en une extension compilée. Le module d'extension converti peut être chargé par importation de la même manière qu'un module Python normal. Si vous utilisez Cython, il semble que vous puissiez accélérer le processus d'accès aux éléments en utilisant le tableau numpy comme interface pour les raisons suivantes.

Implémentez DP asynchrone en Cython. Je voudrais utiliser une file d'attente prioritaire, mais je ne fais que sauter les mises à jour de la valeur d'état avec une petite erreur TD en raison du traitement des objets Python.

%%cython
import numpy as np
cimport numpy as np
cimport cython

ctypedef np.float64_t FLOAT_t
ctypedef np.int64_t INT_t

@cython.boundscheck(False)
@cython.wraparound(False)
def cythonic_backup(
    FLOAT_t[:] V, FLOAT_t[:] Q, FLOAT_t[:] TD, FLOAT_t[:] reward,
    INT_t[:] state2action_start, INT_t[:] action2next_state,
    INT_t[:] next_state2inv_action, INT_t[:, :] inv_action2state_action,
    FLOAT_t gamma, FLOAT_t threshold
):
    cdef INT_t num_updates, state, action, next_state, inv_action, start_inv_action, end_inv_action, start_action, end_action
    cdef FLOAT_t v
    num_updates = 0
    for next_state in range(len(V)):
        if TD[next_state] < threshold:
            continue
            
        num_updates += 1
        TD[next_state] = 0
        start_inv_action = next_state2inv_action[next_state]
        end_inv_action = next_state2inv_action[next_state + 1]
        for inv_action in range(start_inv_action, end_inv_action):
            state = inv_action2state_action[inv_action][0]
            action = inv_action2state_action[inv_action][1]
            Q[action] = reward[action] + gamma * V[next_state]
            start_action = state2action_start[state]
            end_action = state2action_start[state + 1]
            v = -1e9
            for action in range(start_action, end_action):
                if v < Q[action]:
                    v = Q[action]
            if v > V[state]:
                TD[state] += v - V[state]
            else:
                TD[state] += V[state] - v
            V[state] = v
    return num_updates
class CythonicAsyncDP(ArraySyncDP):
    def run(self):
        A = self.A
        for i in range(self.max_sweeps):
            num_updates = cythonic_backup(
                self.V, self.Q, self.TD, A.reward, A.state2action_start, A.action2next_state,
                A.next_state2inv_action, A.inv_action2state_action, self.gamma, self.threshold
            )
            if num_updates == 0:
                return self.V
        raise NonConvergenceError
%timeit V = CythonicAsyncDP(A, gamma, max_sweeps, threshold).run()
18.6 ms ± 947 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Une accélération spectaculaire de 18,6 ms a été atteinte. Avec cette vitesse de traitement, en fonction de l'utilisation prévue, même les problèmes avec environ 1 million d'états peuvent être traités de manière adéquate.

Les références

[1] Python haute performance, Micha Gorelick et Ian Ozsvald, O'Reilly Japon, 2015. [2] Reinforcement Learning, R. S. Sutton and A. G. Barto, The MIT Press, 2018.

Recommended Posts

Apprentissage par renforcement: accélérer l'itération de la valeur
Apprentissage par renforcement futur_2
Apprentissage par renforcement futur_1
Apprentissage amélioré 1 installation de Python
Renforcer l'apprentissage 3 Installation d'OpenAI
Renforcer l'apprentissage de la troisième ligne
[Renforcer l'apprentissage] Tâche de bandit
Apprentissage amélioré Python + Unity (apprentissage)
Renforcer l'apprentissage 1 édition introductive
J'ai essayé les réseaux d'itération de valeur
Apprentissage amélioré 7 Sortie du journal des données d'apprentissage
Renforcer l'apprentissage 28 collaboratif + OpenAI + chainerRL
Renforcer l'apprentissage 19 Colaboratory + Mountain_car + ChainerRL
Renforcement de l'apprentissage 2 Installation de chainerrl
[Renforcer l'apprentissage] Suivi par multi-agents
Renforcer l'apprentissage 6 First Chainer RL
Apprentissage amélioré à partir de Python
Renforcer l'apprentissage 20 Colaboratoire + Pendule + ChainerRL
Apprentissage par renforcement 5 Essayez de programmer CartPole?
Apprentissage par renforcement 9 Remodelage magique ChainerRL
Renforcer l'apprentissage Apprendre d'aujourd'hui
Renforcer l'apprentissage 4 CartPole première étape
Apprentissage par renforcement profond 1 Introduction au renforcement de l'apprentissage
DeepMind Enhanced Learning Framework Acme
Renforcer l'apprentissage 21 Colaboratoire + Pendule + ChainerRL + A2C
TF2RL: bibliothèque d'apprentissage améliorée pour TensorFlow2.x
Apprentissage par renforcement 34 Créez des vidéos d'agent en continu
Renforcer l'apprentissage 13 Essayez Mountain_car avec ChainerRL.
Construction d'un environnement d'apprentissage amélioré Python + Unity
Renforcer l'apprentissage 22 Colaboratory + CartPole + ChainerRL + A3C
Explorez le labyrinthe avec l'apprentissage augmenté
Renforcer l'apprentissage 8 Essayez d'utiliser l'interface utilisateur de Chainer
Renforcer l'apprentissage 24 Colaboratory + CartPole + ChainerRL + ACER
Apprentissage par renforcement 3 Méthode de planification dynamique / méthode TD
Deep Strengthening Learning 3 Édition pratique: Briser des blocs
J'ai essayé l'apprentissage par renforcement avec PyBrain
Apprenez en faisant! Apprentissage par renforcement profond_1