J'obtiens une erreur lorsque j'essaye de sauvegarder KerasRegressor avec pickle ou joblib, Comment le rendre sauvegardable.
Monkey patch Keras Regressor ci-dessous
def KerasRegressor__getstate__(self):
result = { 'sk_params': self.sk_params }
with tempfile.TemporaryDirectory() as dir:
if hasattr(self, 'model'): #Il y a des cas où il n'existe pas en raison d'un clonage par l'estimateur parent, etc.
self.model.save(dir + '/output.h5', include_optimizer=False)
with open(dir + '/output.h5', 'rb') as f:
result['model'] = f.read()
return result
KerasRegressor.__getstate__ = KerasRegressor__getstate__
def KerasRegressor__setstate__(self, serialized):
self.sk_params = serialized['sk_params']
with tempfile.TemporaryDirectory() as dir:
model_data = serialized.get('model')
if model_data:
with open(dir + '/input.h5', 'wb') as f:
f.write(model_data)
self.model = models.load_model(dir + '/input.h5')
KerasRegressor.__setstate__ = KerasRegressor__setstate__
__getstate__, __setstate__
Peut être utilisé pour personnaliser la sérialisation et la désérialisation de pickle pour chaque classe.(Pour plus de détails)
Recommended Posts