Thursday, June 28, 2018

Saving/Loading a Tensorflow model using HDF5 (h5py)

The normal way to save the parameters of a neural network in Tensorflow is to create a tf.train.Saver() object and then call the object's "save" and "restore" methods. But this can be a bit cumbersome, and you might prefer more control over what gets saved and loaded, and how. HDF5 is a standard file format for storing large tensors (such as the parameters of a neural network), and h5py lets you use it from Python.

Saving is easy. Here is the code you'll need:
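import h5py
import tensorflow as tf

# Assumes `session` is a tf.Session in which the variables are initialized or restored.
with h5py.File('model.hdf5', 'w') as f:
    for var in tf.trainable_variables():
        # Scope slashes become spaces so that HDF5 does not create nested groups.
        key = var.name.replace('/', ' ')
        value = session.run(var)  # current value of the variable as a NumPy array
        f.create_dataset(key, data=value)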

Notice that we need the session to get a parameter's value: the variable object belongs to the graph, but its current value is held by the session.
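Because the file only ever sees plain NumPy arrays, the saving scheme can be tried without TensorFlow at all. The sketch below is a minimal, hypothetical pair of helpers (`save_params` and `load_params` are my names, not part of h5py or Tensorflow) that assumes the parameters have already been fetched into a dict mapping variable names to arrays, as the `session.run(var)` calls above would produce.

```python
import os
import tempfile

import h5py
import numpy as np


def save_params(path, params):
    """Write each array in `params` (name -> ndarray) as a top-level dataset.

    Slashes are replaced by spaces so every dataset sits in the root group.
    """
    with h5py.File(path, 'w') as f:
        for name, value in params.items():
            f.create_dataset(name.replace('/', ' '), data=value)


def load_params(path):
    """Read every top-level dataset back into a dict, restoring the slashes."""
    with h5py.File(path, 'r') as f:
        # ds[()] reads the whole dataset into memory as a NumPy array
        return {key.replace(' ', '/'): ds[()] for key, ds in f.items()}


if __name__ == '__main__':
    # Hypothetical parameters standing in for session.run(var) results.
    params = {
        'layer1/weights:0': np.arange(6, dtype=np.float32).reshape(2, 3),
        'layer1/bias:0': np.zeros(3, dtype=np.float32),
    }
    path = os.path.join(tempfile.mkdtemp(), 'model.hdf5')
    save_params(path, params)
    restored = load_params(path)
    for name in params:
        assert np.array_equal(params[name], restored[name])
```

Keeping the file format independent of the graph like this also makes it easy to inspect or modify a saved model from a plain Python script.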

If you're using variable scopes, your variable names will contain slashes, so here we replace the slashes with spaces. This is because HDF5 treats keys as paths, with slashes separating group names, much like folders in a directory. A name such as 'rnn/cell/kernel:0' would therefore be stored as a dataset nested inside two groups, and if you do not know the full names in advance you need to traverse the groups recursively (one level at a time) to reach the data. Replacing the slashes keeps every dataset at the top level, which simplifies the code for loading the model later.
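To see what the slash replacement avoids, here is a small sketch that stores scoped names as-is and then recovers them with h5py's `visititems`, which walks every group and dataset recursively. The variable names are hypothetical, and `collect_datasets` is an illustrative helper of my own, not an h5py function.

```python
import os
import tempfile

import h5py
import numpy as np


def collect_datasets(f):
    """Return {full path: array} for every dataset anywhere in the file."""
    found = {}

    def visitor(name, obj):
        # `name` is the path relative to `f`, e.g. 'rnn/cell/kernel:0'
        if isinstance(obj, h5py.Dataset):
            found[name] = obj[()]
        # returning None tells visititems to keep walking

    f.visititems(visitor)
    return found


path = os.path.join(tempfile.mkdtemp(), 'nested.hdf5')
with h5py.File(path, 'w') as f:
    # Without the replacement, h5py creates the intermediate groups 'rnn' and 'rnn/cell'.
    f.create_dataset('rnn/cell/kernel:0', data=np.ones((2, 2)))
    f.create_dataset('rnn/cell/bias:0', data=np.zeros(2))

with h5py.File(path, 'r') as f:
    top_level = list(f.keys())      # only the group 'rnn' is visible here
    datasets = collect_datasets(f)  # the full variable names, found recursively
```

Both approaches work; the flat layout simply lets the loading loop use `f.items()` directly instead of a recursive walk.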

You can also filter which variables get saved, and save extra information alongside them. I like to save the Tensorflow version in the file so that I can check for incompatible variable names in contrib modules (some RNN variable names changed in version 1.2).
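One way to store such extra information is as an HDF5 attribute on the file itself. The sketch below assumes the parameters are already in a dict of NumPy arrays; `save_params_with_metadata`, `saved_before`, the `keep` predicate and the `'tf_version'` attribute name are all hypothetical choices of mine. In real saving code you would pass `tf.__version__` as the version string.

```python
import os
import tempfile

import h5py
import numpy as np


def save_params_with_metadata(path, params, tf_version, keep=lambda name: True):
    """Save arrays whose names pass `keep`, plus the TF version as a file attribute."""
    with h5py.File(path, 'w') as f:
        f.attrs['tf_version'] = tf_version
        for name, value in params.items():
            if keep(name):
                f.create_dataset(name.replace('/', ' '), data=value)


def saved_before(path, major, minor):
    """True if the file was written by a TF version older than major.minor."""
    with h5py.File(path, 'r') as f:
        version = f.attrs['tf_version']
    if isinstance(version, bytes):  # some h5py versions may return bytes
        version = version.decode()
    saved = tuple(int(part) for part in version.split('.')[:2])
    return saved < (major, minor)
```

At load time, a check such as `saved_before('model.hdf5', 1, 2)` tells you whether the stored RNN variable names may need renaming before they match the current graph.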

Now comes the loading part. Loading is a little more involved because your neural network code needs to include nodes for assigning values to the variables. All you need to do whilst constructing your Tensorflow graph is include the following code:
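import tensorflow as tf

# One placeholder and one assign node per trainable variable.
param_setters = dict()
for var in tf.trainable_variables():
    # base_dtype gives the non-reference dtype (TF1 ref variables report e.g. float32_ref).
    placeholder = tf.placeholder(var.dtype.base_dtype, var.shape, var.name.split(':')[0]+'_setter')
    param_setters[var.name] = (tf.assign(var, placeholder), placeholder)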

This code creates a separate placeholder and assign node for each trainable variable in your graph. To modify a variable, you run its assign node in a session and pass the new value through its placeholder. All the assign nodes and placeholders are kept in a dictionary called param_setters. Each placeholder is named after its variable (minus the ':0' output suffix) with '_setter' appended.

Notice that param_setters is a dictionary mapping variable names to a tuple consisting of the assign node and the placeholder.
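Because every entry holds both halves of an assignment, the dictionary also makes it easy to restore all variables in a single `session.run` call instead of one call per variable. The helper below is a hypothetical sketch (`build_batch_assign` is my name); it only rearranges the contents of param_setters, so it can be tested with stand-in objects.

```python
def build_batch_assign(param_setters, values):
    """Collect assign ops and one combined feed dict for a single session.run.

    param_setters: {variable name: (assign op, placeholder)}, built as above.
    values: {variable name: NumPy array}, e.g. loaded from the HDF5 file.
    """
    assign_ops = []
    feed_dict = {}
    for name, value in values.items():
        assign_op, placeholder = param_setters[name]
        assign_ops.append(assign_op)
        feed_dict[placeholder] = value
    return assign_ops, feed_dict


# With a real graph and session:
#   ops, feed = build_batch_assign(param_setters, values)
#   session.run(ops, feed_dict=feed)
```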

Now we can load the HDF5 file as follows:
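import h5py
import numpy as np

with h5py.File('model.hdf5', 'r') as f:
    for (name, val) in f.items():
        name = name.replace(' ', '/')  # restore the original variable name
        val = np.array(val)  # read the whole dataset into memory
        session.run(param_setters[name][0], { param_setters[name][1]: val })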

Here we load each parameter from the file and turn the spaces in its name back into slashes. We then run the assign node stored in param_setters under that name, feeding the loaded value through the matching placeholder.
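With the loop above, a name in the file that has no matching variable raises a KeyError, while a variable missing from the file silently keeps whatever value it already had. A small check before loading can report both cases. This is a hypothetical helper of my own, assuming keys were saved with slashes replaced by spaces.

```python
def check_keys(file_keys, variable_names):
    """Compare HDF5 keys (slashes stored as spaces) with graph variable names.

    Returns (missing, unexpected): variables with no saved value, and saved
    values with no matching variable, both sorted.
    """
    saved = {key.replace(' ', '/') for key in file_keys}
    expected = set(variable_names)
    return sorted(expected - saved), sorted(saved - expected)


# With a real file and graph:
#   with h5py.File('model.hdf5', 'r') as f:
#       missing, unexpected = check_keys(f.keys(), param_setters.keys())
```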