J'ai fait référence à certains sites pour utiliser Chainer.
Cependant, lorsque j'ai exécuté train_imagenet.py pour apprendre ma propre image, l'erreur suivante s'est produite.
Erreur
cPickle.UnpicklingError: invalid load key,
La partie correspondante est un traitement non-Pickle par la fonction appelée pickle.load du code de ↓
train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
mean_image = pickle.load(open(args.mean, 'rb'))
La valeur de l'argument args.mean est un fichier appelé mean.npy, donc si vous recherchez la source de ce fichier ...
compute_mean.py
#!/usr/bin/env python
import argparse
import os
import sys
import numpy
from PIL import Image
import six.moves.cPickle as pickle
parser = argparse.ArgumentParser(description='Compute images mean array')
parser.add_argument('dataset', help='Path to training image-label list file')
parser.add_argument('--root', '-r', default='.',
help='Root directory path of image files')
parser.add_argument('--output', '-o', default='mean.npy',
help='path to output mean array')
args = parser.parse_args()
sum_image = None
count = 0
for line in open(args.dataset):
filepath = os.path.join(args.root, line.strip().split()[0])
image = numpy.asarray(Image.open(filepath)).transpose(2, 0, 1)
if sum_image is None:
sum_image = numpy.ndarray(image.shape, dtype=numpy.float32)
sum_image[:] = image
else:
sum_image += image
count += 1
sys.stderr.write('\r{}'.format(count))
sys.stderr.flush()
sys.stderr.write('\n')
mean = sum_image / count
pickle.dump(mean, open(args.output, 'wb'), -1)
Il semble que l'objet créé par la fonction numpy.ndarray est sorti dans un fichier appelé mean.npy par la fonction pickle.dump. En d'autres termes, l'entité de mean.npy est comme un flux d'octets d'un tableau NumPy.
Donc, au lieu de lire mean.npy comme non-Pickle dans train_imagenet.py, je l'ai modifié pour le lire comme un tableau NumPy.
train_imagenet.py
# Prepare dataset
train_list = load_image_list(args.train, args.root)
val_list = load_image_list(args.val, args.root)
# mean_image = pickle.load(open(args.mean, 'rb'))← cPickle lorsqu'il est lu comme non-Pickle.UnpicklingError
mean_image = np.load(args.mean) #Lire comme un tableau NumPy
Puis j'ai réussi à le lire.