# Source code for pimlico.datatypes.keras

"""
Datatypes for storing and loading Keras models.

"""
from __future__ import absolute_import
import os
from pimlico.core.dependencies.python import keras_dependency

from pimlico.datatypes.base import PimlicoDatatypeWriter, PimlicoDatatype
from pimlico.utils.core import import_member


class KerasModelWriter(PimlicoDatatypeWriter):
    """
    Writer that stores a Keras model of either type, which works because both model types
    expose the same storage interface.

    The writer requires two tasks to be completed: ``architecture`` (the model's JSON
    description, written to ``architecture.json``) and ``weights`` (written to
    ``weights.hdf5``), both inside the data dir.

    """
    def __init__(self, base_dir, **kwargs):
        super(KerasModelWriter, self).__init__(base_dir, **kwargs)
        self.require_tasks("architecture", "weights")
        # Destination used by write_weights()
        self.weights_filename = os.path.join(self.data_dir, "weights.hdf5")

    def write_model(self, model):
        """Store the whole model: architecture first, then weights."""
        for write_part in (self.write_architecture, self.write_weights):
            write_part(model)

    def write_architecture(self, model):
        """
        Write the JSON produced by ``model.to_json()`` to ``architecture.json`` and mark the
        ``architecture`` task as complete.

        """
        arch_path = os.path.join(self.data_dir, "architecture.json")
        with open(arch_path, "w") as arch_out:
            arch_out.write(model.to_json())
        self.task_complete("architecture")

    def write_weights(self, model):
        """
        Save the model's weights to ``weights_filename``, overwriting any existing file, and
        mark the ``weights`` task as complete.

        """
        model.save_weights(self.weights_filename, overwrite=True)
        self.task_complete("weights")
class KerasModel(PimlicoDatatype):
    """
    Datatype for both types of Keras models, stored using Keras' own storage mechanisms.

    The stored data is read from two files in the data dir: ``architecture.json`` (the
    model's JSON description) and ``weights.hdf5`` (the model's weights).

    """
    # Override to pass in extra values in Keras' custom objects arg to model_from_json.
    # Values may be given directly, or as strings holding fully-qualified Python names,
    # which get imported when the model is loaded.
    custom_objects = {}

    def get_software_dependencies(self):
        return super(KerasModel, self).get_software_dependencies() + [keras_dependency]

    def get_custom_objects(self):
        """
        Build the ``custom_objects`` dict that is passed to Keras when loading the model.

        :return: a new dict with the same keys as ``custom_objects``; any value given as a
            string is replaced by the member that string names (via ``import_member``),
            other values are passed through unchanged
        """
        # Resolve the string base type in a way that works on Python 2 and Python 3:
        # ``basestring`` (and ``dict.iteritems``) do not exist on Python 3.
        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3
        return {
            name: import_member(cls) if isinstance(cls, string_types) else cls
            for name, cls in self.custom_objects.items()
        }

    def load_model(self):
        """
        Rebuild the stored model: the architecture from ``architecture.json`` (using the
        custom objects from :meth:`get_custom_objects`), then the weights from
        ``weights.hdf5``.

        :return: the loaded Keras model
        """
        # Imported here rather than at module level; presumably so this module can be
        # imported without Keras installed (Keras is declared in get_software_dependencies)
        from keras.models import model_from_json
        # Load the model architecture
        with open(os.path.join(self.data_dir, "architecture.json"), "r") as arch_file:
            model = model_from_json(arch_file.read(), custom_objects=self.get_custom_objects())
        # Set the weights to those stored
        model.load_weights(os.path.join(self.data_dir, "weights.hdf5"))
        return model