This is a quick tip on how to use combine_adversarial_loss with tf.contrib.gan.estimator.GANEstimator. In my latest projects, I have been using TensorFlow estimators. Estimators let you focus on building models and wrap the whole training process (including saving, exporting, and putting a model into production) into a few lines of code. Recently, I ran into the limits of estimators when I wanted to train a generative adversarial network (GAN) with a combined adversarial loss. In this article, I will show you a little trick for doing that.

TF tricks #1: Using GANs

I don’t want to go into too much detail about GANs, but here is a little motivation behind this problem. I worked on a task where I had photos of textures and wanted to recover dark areas (damaged by dust or dirt). A simple convolutional neural network learned how to fix the color and preserve the good parts, but it couldn’t generate the texture. Training the network as a GAN forces it to generate the texture, which fixes the problem. Since I wanted to generate texture only in the damaged areas, I used a combined adversarial loss. It combines the adversarial loss with a standard CNN loss, which forces the network to learn which areas should be preserved and which should be generated.

Solution

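This solution was tested on TensorFlow r1.12. From reading the source code, tf.contrib.gan.losses.combine_adversarial_loss takes a gan_loss tuple (the discriminator and generator losses) and replaces the generator loss with the combined adversarial loss. As far as I can tell, the estimator only accepts the individual loss functions and builds this tuple internally, so there is no place to pass the tuple through combine_adversarial_loss. That means we need to replace generator_loss_fn in the estimator instead. The losses.combine_adv... function is only a wrapper around losses.wargs.combine_adv..., which actually computes the loss.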
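
To make the mechanics concrete, here is a simplified, TensorFlow-free mimic of what the two functions do when weight_factor is used, based on my reading of the r1.12 source. GANLoss is modelled as a plain namedtuple and the losses are plain floats; combine_losses and combine_gan_loss are hypothetical stand-ins, not the real API (the real functions work on tensors and also support gradient_ratio and summaries).

```python
from collections import namedtuple

# Hypothetical stand-in for tfgan's GANLoss tuple; plain floats replace tensors.
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])


def combine_losses(main_loss, adversarial_loss, weight_factor):
    """Mimics the weight_factor path of losses.wargs.combine_adversarial_loss."""
    if weight_factor is None:
        raise ValueError("this simplified sketch only supports weight_factor")
    # Combined loss = main (non-adversarial) loss + weighted adversarial loss.
    return main_loss + weight_factor * adversarial_loss


def combine_gan_loss(gan_loss, non_adversarial_loss, weight_factor):
    """Mimics the tuple wrapper: only the generator loss is replaced."""
    combined = combine_losses(non_adversarial_loss, gan_loss.generator_loss, weight_factor)
    return gan_loss._replace(generator_loss=combined)


losses = GANLoss(generator_loss=2.0, discriminator_loss=0.5)
print(combine_gan_loss(losses, non_adversarial_loss=1.0, weight_factor=0.25))
# GANLoss(generator_loss=1.5, discriminator_loss=0.5)
```

The discriminator loss passes through untouched, which is why the combination has to happen inside the generator loss function when the estimator builds the tuple for you.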

All the other loss functions used by the estimator take the arguments (gan_model, **kwargs). So we define our own function with the same signature and pass it to the estimator as the generator loss function. In this function, we compute the adversarial and non-adversarial losses and combine them using combine_adversarial_loss. The **kwargs can’t simply be forwarded to the combining function, because it doesn’t accept the same keyword arguments. In my experience, they only contain the summary setting, but this changes between versions.

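Quick note: weight_factor must be a float (for example 1.0 rather than 1); otherwise you will get an error, presumably because an integer factor becomes an int32 tensor that can’t be multiplied with the float32 adversarial loss.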
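
A minimal sketch of such a generator loss function for TF r1.12 is shown below. The Wasserstein generator and discriminator losses, the L1 non-adversarial loss (tf.losses.absolute_difference), ADV_WEIGHT = 1.0 and the Adam settings are only assumptions for illustration; make_estimator is a hypothetical helper, and generator_fn and discriminator_fn stand for your own network functions.

```python
import tensorflow as tf  # TF 1.x only: tf.contrib (and tf.contrib.gan) was removed in TF 2.x

# Assumed weight of the adversarial term; it must be a float (1.0, not 1).
ADV_WEIGHT = 1.0


def generator_loss_fn(gan_model, **kwargs):
    """Custom generator loss with the same (gan_model, **kwargs) signature as tfgan's losses."""
    tfgan = tf.contrib.gan
    # The tuple-style tfgan losses accept the estimator's kwargs, so forward them here.
    # The Wasserstein loss is only an example choice.
    adversarial_loss = tfgan.losses.wasserstein_generator_loss(gan_model, **kwargs)
    # Assumed non-adversarial ("CNN") loss: L1 distance between clean and generated images.
    non_adversarial_loss = tf.losses.absolute_difference(
        labels=gan_model.real_data, predictions=gan_model.generated_data)
    # The kwargs are NOT forwarded here: wargs.combine_adversarial_loss has a different signature.
    return tfgan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        adversarial_loss,
        weight_factor=ADV_WEIGHT,
        variables=gan_model.generator_variables)


def make_estimator(generator_fn, discriminator_fn, model_dir=None):
    """Hypothetical helper: a GANEstimator that uses the custom generator loss."""
    tfgan = tf.contrib.gan
    return tfgan.estimator.GANEstimator(
        model_dir=model_dir,
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))
```

The estimator is then trained with the usual estimator.train(input_fn) call, where input_fn returns the damaged images as features (the generator inputs) and the clean images as labels (gan_model.real_data).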

Summary

I have always been a little careful about using higher-level APIs for training machine learning models. They often feel limiting for my ideas and experiments. However, understanding the inner workings of the functions can give you that freedom back and, at the same time, speed up development.

I have been surprised by the progress in the area of GANs. A few years ago, when I first heard about them, they only produced small, blurry images. Training a large GAN is still extremely tricky without the right approach, but with enough patience it can produce good results.