Fake it till you make it: GANs with HDF5

Hi, folks!

I wanted to share something with you today: a recipe to generate different synthetic data each time you read a given HDF5 dataset.

If you have ever looked into the TensorFlow Keras API, you may have noticed that it is possible to train a machine learning model and save the trained model weights to an HDF5 file. Afterwards, one can load that trained model and use it for, e.g., inference.

GANs have become a popular architecture for creating synthetic datasets. The architecture comprises a generator model (which attempts to generate realistic data) and a discriminator (which classifies the data it is shown as fake or real). And again, once trained, the generator model can be serialized to HDF5 – as shown in the pseudo-code below:

def train(...):
    for ...
        discriminator.train(real_data)
        discriminator.train(generator.gen_fake_data())
        loss = gan.train()
        ...
    generator.save('generator.h5')

The Keras API to load that model and generate new synthetic data is just as simple:

from keras.models import load_model
model = load_model('generator.h5')
synthetic_data = model.predict(...)

Now, here is where the interesting part begins! :upside_down_face:

Since we already have the model as HDF5, we can embed the logic from the previous snippet in a dataset. That way, each time the dataset is opened and read, we seamlessly execute code that loads the model, generates new synthetic data, and populates the dataset values according to that data!
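
To make the idea concrete before bringing HDF5-UDF into the picture, here is a toy sketch in plain Python. The `DynamicDataset` class and `fake_generator` function are hypothetical names made up for this illustration; they are not part of HDF5-UDF or Keras. The point is only that the values do not live in storage: they are produced by a function each time someone reads.

```python
import random


class DynamicDataset:
    """Toy stand-in for a dataset whose values are computed on every read.

    Hypothetical illustration only: this is not the HDF5-UDF API.
    """

    def __init__(self, shape, generate):
        self.shape = shape          # (rows, cols) registered for the dataset
        self._generate = generate   # callable that returns rows * cols values

    def read(self):
        rows, cols = self.shape
        values = self._generate(rows * cols)
        if len(values) != rows * cols:
            raise ValueError("generator returned the wrong number of values")
        # Split the flat list of values into rows
        return [values[r * cols:(r + 1) * cols] for r in range(rows)]


def fake_generator(n):
    # Stand-in for model.predict(): returns n pseudo-random floats
    return [random.gauss(0.0, 1.0) for _ in range(n)]


dataset = DynamicDataset((2, 3), fake_generator)
first = dataset.read()
second = dataset.read()   # same shape as `first`, freshly generated values
```

With HDF5-UDF, the role of `read()` is played by the HDF5 library itself, so applications need no special code to trigger the generation.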

We begin by creating a UDF Python script and moving the logic from the previous snippet into the function executed by HDF5-UDF:

def dynamic_dataset():
    # First, we load the required modules
    from keras.models import load_model
    from numpy.random import randn

    # Then, we use HDF5-UDF's introspection API to retrieve the path to the
    # HDF5 file and provide that path to Keras' load_model() function
    path = lib.getFilePath().decode('utf-8')
    model = load_model(path)

    # Next, we generate a new random image with model.predict(), feeding it
    # latent points with the same input shape the network was trained with
    latent_points = randn(280).reshape(1, 280)
    X = model.predict(latent_points)[0,:]

    # Last, we initialize the values of our `synthetic_data` dataset with the
    # contents of the generated image. HDF5-UDF's `getData()` and `getDims()`
    # APIs are used to get a pointer to the dataset buffer and its registered
    # dimensions, respectively
    udf_data = lib.getData("synthetic_data")
    udf_dims = lib.getDims("synthetic_data")
    size = udf_dims[0] * udf_dims[1]
    udf_data[0:size] = X.flatten()[:]
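
The last two lines of the UDF copy a 2-D image into a buffer that is indexed as a flat, one-dimensional array. Here is a small NumPy sketch of that step on its own, assuming a row-major (C order) layout. The `dims`, `buffer`, and `image` names are stand-ins made up for the example; in the UDF they correspond to what `lib.getDims()`, `lib.getData()`, and `model.predict()` give you.

```python
import numpy as np

# Hypothetical stand-ins for lib.getDims() and lib.getData()
dims = (2, 3)
size = dims[0] * dims[1]
buffer = np.zeros(size, dtype=np.float32)

# Stand-in for the 2-D image returned by model.predict(...)[0, :]
image = np.arange(size, dtype=np.float32).reshape(dims)

# flatten() returns a row-major copy, so element (r, c) of the image
# lands at index r * dims[1] + c of the flat buffer
buffer[0:size] = image.flatten()

r, c = 1, 2
assert buffer[r * dims[1] + c] == image[r, c]
```

If the generated image had a different number of elements than the registered dimensions, the slice assignment would fail, so the shape given on the command line has to match what the model produces.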

The UDF file (here called gan.py) can be compiled and associated with a new dataset by running the following command:

$ hdf5-udf generator.h5 gan.py synthetic_data:280x280:float

Now, every time an application reads the synthetic_data dataset, it gets brand new synthetic data generated by the model built into that very file!
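
On the application side nothing special should be needed. Here is a minimal sketch using h5py, assuming h5py is installed and that the HDF5-UDF filter is available to the HDF5 library on your system. The helper names `read_synthetic` and `reads_differ` are made up for this example. The file is reopened for each read, on the assumption that this avoids getting values cached from a previous read.

```python
def read_synthetic(path="generator.h5", name="synthetic_data"):
    """Open the file, read the whole dataset into memory, and close the file."""
    import h5py  # imported here so the sketch only needs h5py when it is called

    with h5py.File(path, "r") as f:
        return f[name][...]


def reads_differ(path="generator.h5", name="synthetic_data"):
    """Return True if two open-and-read cycles yield different values."""
    first = read_synthetic(path, name)
    second = read_synthetic(path, name)
    return bool((first != second).any())
```

Calling `reads_differ()` against the file created above is a quick way to check that the UDF really runs on every read.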

As you can see from the image above, the GAN generated an image that looks like a shoe – which is a good thing; after all, this is a sample network trained on the Fashion MNIST dataset :grin:

I mentioned this proof of concept very briefly at the last European HDF5 Users Group Meeting, but I thought it deserved a more detailed post to explain what is going on under the hood.

Cheers,
Lucas
