I get an error when trying to save a KerasRegressor with pickle or joblib. How can I make it saveable?
Monkey-patch KerasRegressor as shown below:
def KerasRegressor__getstate__(self) -> dict:
    """Return a picklable snapshot of a KerasRegressor instance.

    Every instance attribute (``build_fn``, ``sk_params``, ...) is copied
    as-is, except ``model``. A live Keras model cannot be pickled directly,
    so it is written to a temporary HDF5 file and its raw bytes are stored
    under the ``'model'`` key instead.

    The ``'sk_params'`` and ``'model'`` keys keep their original meaning, so
    a ``__setstate__`` that only restores those two keys can still read
    this state.

    NOTE: ``build_fn`` is now part of the state, so it must itself be
    picklable (e.g. a module-level function rather than a lambda).
    """
    import os
    import tempfile

    # Keep all attributes except the live model. Dropping build_fn (as the
    # earlier version did) breaks get_params()/clone() after unpickling.
    state = {key: value for key, value in self.__dict__.items() if key != 'model'}

    # There are cases where `model` does not exist, e.g. an unfitted copy
    # created by cloning inside a parent estimator.
    if hasattr(self, 'model'):
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, 'output.h5')
            # NOTE(review): the optimizer is deliberately not saved. The
            # reloaded model may therefore come back uncompiled: predict()
            # should work, but evaluate()/score() may not. Confirm against
            # the Keras version in use.
            self.model.save(path, include_optimizer=False)
            with open(path, 'rb') as f:
                state['model'] = f.read()
    return state
# Install the custom pickling hook on the class itself, so every
# KerasRegressor instance uses it when pickled (including via joblib).
KerasRegressor.__getstate__ = KerasRegressor__getstate__
def KerasRegressor__setstate__(self, serialized: dict) -> None:
    """Restore a KerasRegressor instance from its pickled state dict.

    Two state formats are accepted:

    * the full-attribute format (``build_fn``, ``sk_params``, ...), and
    * the older ``{'sk_params': ..., 'model': ...}`` format.

    Every key except ``'model'`` is restored as an instance attribute. If
    model bytes are present, they are written to a temporary HDF5 file and
    loaded back with ``models.load_model``. ``models`` must be in scope at
    module level.
    """
    import os
    import tempfile

    # Copy so that pop() does not mutate the caller's dict.
    state = dict(serialized)
    model_data = state.pop('model', None)
    self.__dict__.update(state)

    if model_data:
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, 'input.h5')
            with open(path, 'wb') as f:
                f.write(model_data)
            # Load only after the write handle is closed, so every byte has
            # been flushed to disk.
            self.model = models.load_model(path)
# Install the matching unpickling hook on the class itself, so every
# KerasRegressor instance uses it when loaded with pickle or joblib.
KerasRegressor.__setstate__ = KerasRegressor__setstate__
`__getstate__` and `__setstate__` can be defined on a class to customize how pickle serializes and deserializes its instances (see "Pickling Class Instances" in the Python `pickle` documentation for details).
Recommended Posts