Sunday, November 29, 2020

Summary of research paper "A Three-Player GAN: Generating Hard Samples To Improve Classification Networks" by Vandenhende

After reading this paper, I was confused by Algorithm 1 (page 2), as it did not quite match parts of the text. So I contacted the author, and I would now like to clarify it for everyone else.

The paper describes a generative adversarial network that generates images which are hard to classify, in order to augment the training set of a second, target classifier. This is done by attaching a classifier to the generator (in addition to the discriminator) which learns to classify the generated data, with a gradient reversal layer placed between the classifier and the generator. The gradient reversal layer forces the generator to produce hard-to-classify examples by negating the gradient of the classifier loss before passing it back to the generator. The generator therefore learns to maximise the classifier loss whilst the classifier is learning to minimise it, resulting in adversarial learning. The author does not give a reason why the same kind of adversarial learning is not applied to both parts of the model.
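To make the gradient reversal trick concrete, here is a minimal PyTorch sketch of such a layer (my own illustrative implementation, not code from the paper): it is the identity in the forward pass, and in the backward pass it negates the incoming gradient and scales it by a lambda hyperparameter.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates and scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the gradient and scale it by lambda.
        # (No gradient is needed with respect to lambd itself.)
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)
```

Anything placed before this layer (here, the generator) receives the negated classifier gradient and so ends up climbing the classifier loss rather than descending it, while everything after the layer trains normally.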

What confused me was this: when training the target classifier on the augmented training set, how would you know which class the difficult-to-classify examples produced by the generator belong to? The source of my confusion was that I did not realise this is not a plain GAN but a CGAN, a conditional generative adversarial network. A CGAN works by feeding both the generator and the discriminator a class label: the true label for real data, and a randomly chosen label for generated data. Since the generator and the discriminator are given the same label, the discriminator learns to spot fake data by checking whether an image looks like what is expected given its label, whilst the generator learns to produce images that match the given label in order to fool the discriminator. The generated images therefore come with known labels.
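A common way to implement this conditioning is to embed the label and concatenate it to the inputs of both networks. The sketch below is a generic CGAN skeleton under assumed toy sizes (`NUM_CLASSES`, `LATENT_DIM`, `IMG_DIM` are illustrative), not the architecture from the paper:

```python
import torch
import torch.nn as nn

NUM_CLASSES, LATENT_DIM, IMG_DIM = 10, 64, 784  # illustrative sizes

class CondGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM + NUM_CLASSES, 256), nn.ReLU(),
            nn.Linear(256, IMG_DIM), nn.Tanh())

    def forward(self, z, labels):
        # Condition on the class by concatenating its embedding to the noise.
        return self.net(torch.cat([z, self.label_emb(labels)], dim=1))

class CondDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(NUM_CLASSES, NUM_CLASSES)
        self.net = nn.Sequential(
            nn.Linear(IMG_DIM + NUM_CLASSES, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x, labels):
        # The discriminator sees the same label, so it can judge whether
        # the image looks plausible *for that class*.
        return self.net(torch.cat([x, self.label_emb(labels)], dim=1))
```

Because every generated image is produced from a known label, that label can be reused directly when the sample is added to the target classifier's training set.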

By combining this with the gradient reversal layer, the generated images are both similar to real data of the given label and difficult to classify. These two properties would be contradictory, except that the effect of the classifier is scaled down by the lambda hyperparameter of the gradient reversal layer, so the discriminator takes precedence.
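Putting the pieces together, the generator's objective can be sketched as one combined loss. This is my own rough reading of the setup, with illustrative names and a stand-in lambda value, not the paper's exact formulation:

```python
import torch
import torch.nn.functional as F

class GradReverse(torch.autograd.Function):
    """Identity forward; negated, lambda-scaled gradient backward."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

def generator_loss(G, D, C, z, labels, lambd=0.1):
    """Hypothetical combined generator objective.

    G: conditional generator, D: conditional discriminator,
    C: classifier attached behind the gradient reversal layer.
    """
    fake = G(z, labels)
    # Adversarial term: the generator wants D to call its fakes real,
    # which pulls the samples towards real data of the given label.
    adv_loss = F.binary_cross_entropy(D(fake, labels),
                                      torch.ones(fake.size(0), 1))
    # Classifier term: the GRL flips this gradient w.r.t. G, so minimising
    # it trains C normally while pushing G towards samples C misclassifies.
    cls_loss = F.cross_entropy(C(GradReverse.apply(fake, lambd)), labels)
    # A small lambd keeps the realism (discriminator) term dominant.
    return adv_loss + cls_loss
```

With a small lambda, the reversed classifier gradient only nudges the generator towards the classifier's decision boundaries instead of overpowering the discriminator's pull towards realistic, label-consistent images.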