Estimating the Q function with GANs
I recently came across an interesting paper titled “GAN Q Learning”. I’ve spent some time experimenting with the algorithm and, in the process, gained valuable insight into the inner workings of the two deep learning paradigms it combines: GANs and Q-learning. Before explaining the algorithm itself, I’ll give a brief introduction to each.
A Quick Review of GANs
There are a ton of online tutorials on GANs, but a GAN boils down to two parts: the discriminator and the generator. They can be explained with one great analogy: the generator is a counterfeit money operation, seeking to pass off fake dollar bills, while the discriminator is a federal agency, seeking to stop the generator from getting away with it. In more precise terms, the generator tries to imitate some data distribution starting from a random seed z, while the discriminator learns to tell real samples from generated ones, and its feedback tells the generator which features it still fails to copy.
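To make the two roles concrete, here is a minimal sketch of the standard GAN losses in plain Python. It assumes the discriminator outputs a probability that its input is real. The function names and the choice of the non-saturating generator loss are illustrative assumptions, not details taken from the paper.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Average of -[log D(x) + log(1 - D(G(z)))] over paired batches.

    d_real: discriminator outputs (probabilities) on real samples.
    d_fake: discriminator outputs (probabilities) on generated samples.
    """
    terms = [-(math.log(r) + math.log(1.0 - f)) for r, f in zip(d_real, d_fake)]
    return sum(terms) / len(terms)

def generator_loss(d_fake):
    """Non-saturating generator loss: average of -log D(G(z))."""
    return sum(-math.log(f) for f in d_fake) / len(d_fake)

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
confused = discriminator_loss([0.5, 0.5], [0.5, 0.5])
# A confident discriminator (real -> 0.9, fake -> 0.1) has a lower loss.
confident = discriminator_loss([0.9, 0.9], [0.1, 0.1])
```

The two losses pull in opposite directions: as the discriminator gets better at catching fakes, its own loss falls while the generator's loss rises.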
While GANs have produced some amazing results, these cherry-picked examples come with a whole host of training issues. Fundamentally, the GAN training paradigm (which is based on basic game theory) tends to break down in one of three ways:
The generator becomes too strong, meaning that the discriminator will never be able to tell real and fake images apart. Note that this indicates that the discriminator is too simple.
The discriminator becomes too strong, meaning that the generator becomes stuck and can’t improve. This often happens when the generator is too simple relative to the discriminator.
The generator ignores the random seed and produces nothing of value, such as the same image every time or just a blur (a failure often called mode collapse).
Much of the recent work on GANs has tried to alleviate these training difficulties.
A Quick Review of Q-Learning
Reinforcement learning is a branch of machine learning that ties in concepts from game theory and more conventional applied math. The fundamental problem it tries to solve is the Markov Decision Process (MDP). In an MDP, an agent (robot, human, animal) tries to maximize the long-term reward from taking a sequence of actions in a given environment. For example, an MDP could describe a:
Robot trying to follow a path in the sand
Human trying to maximize his career prospects
Animal trying to survive
Q-learning is a reinforcement learning technique that estimates a Q function, which is essentially the total future reward expected from taking a given action in a given state. A common illustration of this principle is the Stanford Marshmallow Experiment, in which the optimal choice is to wait for the extra marshmallow, even though waiting brings no immediate benefit. Q-learning has been used in many well-known projects, such as DeepMind’s Atari demonstration, and is a staple of reinforcement learning.
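As a rough illustration, the classic tabular Q-learning update can be sketched in a few lines of Python. The toy marshmallow encoding below (state names, rewards, learning rate, and discount factor) is entirely made up for this example. Deep variants replace the table with a neural network.

```python
from collections import defaultdict

def q_update(q, state, action, reward, next_state, actions,
             alpha=0.1, gamma=0.99, done=False):
    """Apply one tabular Q-learning update and return the new Q(state, action)."""
    # Value of the best action in the next state (0 if the episode has ended).
    best_next = 0.0 if done else max(q[(next_state, a)] for a in actions)
    target = reward + gamma * best_next
    q[(state, action)] += alpha * (target - q[(state, action)])
    return q[(state, action)]

# Toy marshmallow problem (hypothetical encoding): eating now pays 1 and ends
# the episode; waiting pays 0 now, then eating on the next step pays 2.
q = defaultdict(float)
actions = ["eat", "wait"]
for _ in range(200):
    q_update(q, "start", "eat", 1.0, "end", actions, done=True)
    q_update(q, "later", "eat", 2.0, "end", actions, done=True)
    q_update(q, "start", "wait", 0.0, "later", actions)
```

After these updates, the value of waiting at the start is close to 0.99 × 2 = 1.98, which is higher than the value of eating immediately (close to 1), so the greedy choice is to wait.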
GAN Q-Learning Algorithm
Now, these two subjects come from wildly different areas, one from generative modelling and the other from reinforcement learning, but as previous work on algorithms such as GAIL has shown, generative models have great promise in reinforcement learning. In this case, the generator of the GAN architecture estimates the Q function, and the agent acts on those estimates.
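A minimal sketch of what "acting on the generator's estimate" could look like is shown below. Everything here is an assumption for illustration: toy_generator is a hypothetical linear stand-in for a trained generator network, the weights are made up, and greedy action selection is just one plausible way to use the sampled Q-values, not necessarily the paper's exact procedure.

```python
import random

def toy_generator(state, z, weights):
    """Hypothetical stand-in for a trained generator network.

    Returns one Q-value estimate per action: a linear score of the state
    plus that action's noise component.
    """
    return [sum(w * s for w, s in zip(row, state)) + z_a
            for row, z_a in zip(weights, z)]

def select_action(state, weights, rng, noise_scale=0.1):
    """Sample a noise vector, query the generator, and act greedily."""
    z = [rng.gauss(0.0, noise_scale) for _ in weights]
    q_values = toy_generator(state, z, weights)
    return max(range(len(q_values)), key=lambda a: q_values[a])

rng = random.Random(0)
weights = [[1.0, 0.0], [0.0, 1.0]]  # made-up parameters, one row per action
action = select_action([0.2, 0.9], weights, rng)
```

Because the Q-values depend on the sampled noise, two calls in the same state can return different actions when the estimates are close, which gives the agent a built-in source of exploration.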
This algorithm is beneficial in that it allows for more complex modelling of system dynamics, which could prove more useful than traditional Q-learning methods. This is because the generative model learns a distribution over Q-values rather than a single point estimate fitted to examples. Furthermore, the paper shows that, for some tasks, it achieves better results than Q-learning as a consequence of this ability to model more complex reward functions.
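To see why a distribution can carry more information than a single number, consider the toy sketch below. The two samplers are hypothetical stand-ins for a generator queried with fresh noise for two different actions; the numbers are invented and are not results from the paper.

```python
import random
import statistics

def summarize_q_samples(sample_q, n_samples=1000, seed=0):
    """Draw Q-value samples from a sampler and return (mean, std deviation)."""
    rng = random.Random(seed)
    samples = [sample_q(rng) for _ in range(n_samples)]
    return statistics.fmean(samples), statistics.pstdev(samples)

def safe_action(rng):
    """Hypothetical action whose return is almost always about 5."""
    return 5.0 + rng.gauss(0.0, 0.1)

def risky_action(rng):
    """Hypothetical action whose return is 0 or 10 with equal probability."""
    return rng.choice([0.0, 10.0])

safe_mean, safe_std = summarize_q_samples(safe_action)
risky_mean, risky_std = summarize_q_samples(risky_action)
```

Both actions have an average return of about 5, so a point estimate cannot tell them apart, while the spread of the samples makes the difference between them obvious.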
However, this algorithm is incredibly difficult to train. Q-learning in and of itself is not guaranteed to converge, and GANs have a whole assortment of issues plaguing their viability. In my attempts to recreate the paper’s results on the well-known OpenAI Gym CartPole environment, I found that most network specifications failed almost instantly. So far, I have not been able to get this algorithm to converge.
My Final Thoughts
It’s interesting to see neural network models, especially generative models, used in reinforcement learning. Whether it be GANs, VAEs, or even simple CNNs, I look forward to seeing similar work in the future. For this paper in particular, though, more mathematics explaining how and when the model converges would add a dimension it currently lacks. Overall, this is an interesting read, and I wonder whether any similar algorithms will improve on it in the near future.
My implementation of the algorithm can be found here.