Dans la Procédure pour apprendre et inférer le modèle de traduction anglais-japonais du transformateur avec CloudTPU, le modèle de traduction anglais-japonais du transformateur a été appris avec CloudTPU et l'inférence a également été effectuée. C'était. Cette fois, je vais vous expliquer comment exécuter un transformateur formé par Cloud TPU dans un conteneur Docker local. Le code est ici. https://github.com/yolo-kiyoshi/transformer_python_exec
GCS, supposons que les fichiers se trouvent localement dans la structure de répertoires suivante.
Structure du répertoire
bucket
├── training/
│ └── transformer_ende/
│ ├── checkpoint
│ ├── model.ckpt-****.data-00000-of-00001
│ ├── model.ckpt-****.index
│ └── model.ckpt-****.meta
└── transformer/
└── vocab.translate_jpen.****.subwords
Clonez le référentiel.
git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git
Structure du répertoire
.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│ └── transformer_ende/
└── transformer/
Téléchargez le fichier d'informations d'identification du compte de service (json) et placez-le dans le même répertoire que README.md.
Dupliquez et renommez .env.sample
pour créer .env
.
.env
#Décrivez le chemin du fichier d'identification placé au-dessus
GOOGLE_APPLICATION_CREDENTIALS=*****.json
BUDGET_NAME=
#Mêmes paramètres que lors de l'apprentissage avec CloudTPU
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer
Après avoir exécuté la commande suivante, vous pouvez utiliser Jupyter lab en accédant à http: // localhost: 8080 / lab.
docker-compose up -d
Notebook
Téléchargez localement l'ensemble des fichiers point de contrôle
et des fichiers vocaux
créés pendant le processus d'apprentissage du transformateur à partir de GCS.
#Méthode pour télécharger des fichiers depuis GCS(https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
"""Downloads a blob from the bucket."""
storage_client = storage.Client()
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)
print('Blob {} downloaded to {}.'.format(
source_blob_name,
destination_file_name))
#Se référer à la méthode d'acquisition de la liste de fichiers GCS
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
"""Lists all the blobs in the bucket that begin with the prefix."""
storage_client = storage.Client()
# Note: Client.list_blobs requires at least package version 1.17.0.
blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)
file_list = [blob.name for blob in blobs if search_path in blob.name]
return file_list
#Définir les variables d'environnement
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']
#chemin du fichier de point de contrôle
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
#Télécharger le fichier de point de contrôle depuis GCS
download_blob(BUDGET_NAME, src_file_name, dist_file_name)
#Dernière séquence de point de contrôle du fichier de point de contrôle(prefix)Obtenir
import re
with open(dist_file_name) as f:
l = f.readlines(1)
ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', l[0])[0]
ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
#Obtenez la liste de fichiers associée au dernier point de contrôle de GCS
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)
# checkpoint.Téléchargez un ensemble de variables
for ckpt_file in ckpt_file_list:
download_blob(BUDGET_NAME, ckpt_file, ckpt_file)
#Obtenez le chemin du fichier de vocabulaire à partir de GCS
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
#Télécharger le fichier de vocabulaire depuis GCS
download_blob(BUDGET_NAME, vocab_file, vocab_file)
Chargez le modèle de transformateur en fonction des résultats de la formation sur le transformateur téléchargés à partir de GCS.
#Initialisation
tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys
import pickle
import numpy as np
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
#Prétraitement&Utilisez le même nom de classe que PROBLE défini dans l'apprentissage
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
@property
def approx_vocab_size(self):
return 2**13
enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)
from functools import wraps
import time
def stop_watch(func) :
@wraps(func)
def wrapper(*args, **kargs) :
start = time.time()
print(f'{func.__name__} started ...')
result = func(*args,**kargs)
elapsed_time = time.time() - start
print(f'elapsed_time:{elapsed_time}')
print(f'{func.__name__} completed')
return result
return wrapper
@stop_watch
def translate(inputs):
encoded_inputs = encode(inputs)
with tfe.restore_variables_on_create(ckpt_path):
model_output = translate_model.infer(features=encoded_inputs)["outputs"]
return decode(model_output)
def encode(input_str, output_str=None):
"""Input str to features dict, ready for inference"""
inputs = encoders["inputs"].encode(input_str) + [1]
batch_inputs = tf.reshape(inputs, [1, -1, 1])
return {"inputs": batch_inputs}
def decode(integers):
"""List of ints to str"""
integers = list(np.squeeze(integers))
if 1 in integers:
integers = integers[:integers.index(1)]
return encoders["inputs"].decode(np.squeeze(integers))
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)
Inférer avec le modèle de transformateur chargé. Lorsqu'elle est exécutée localement, une phrase prend environ 30 secondes.
inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)
résultat
>Mon chat est très mignon.