A Guide for Building GANs - 10 Tips and Tricks
After having focused on GANs exclusively for the last year and a half, I wanted to expand upon Soumith Chintala’s famous article How to Train a GAN? Tips and tricks to make GANs work and write down the 10 most important general GAN lessons I have learned.
I’ve ordered them chronologically, so the post reads like a detailed guide you can follow along as you progress through your project.
Note that most of these tips apply to other ML models as well. They are generally good practices but especially important for GANs.
1) Get familiar with your problem and data
Before you do anything, always visualize your data and think carefully about what you want to do with it. This cannot be emphasized enough.
For instance, if you work on images, you should know:
- How many samples do you have?
- What is the resolution of your images?
- Is there anything special about your images (colors, objects, …)?
- Are there outliers or other potential problems?
- What do you want to do with your images?
- How hard is the task?
Only after you understand your problem and data well should you start thinking about solutions.
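To make this concrete, here is a minimal sketch of the kind of quick inspection I mean, assuming an image folder loaded with torchvision; the path is a placeholder for your own dataset:
```python
from torchvision import datasets, transforms

# "data/my_images" is a placeholder path; point it at your own dataset.
dataset = datasets.ImageFolder("data/my_images", transform=transforms.ToTensor())
print(f"Number of samples: {len(dataset)}")

# Look at a handful of samples: shapes, value ranges, labels.
for i in range(5):
    img, label = dataset[i]
    print(
        f"sample {i}: shape={tuple(img.shape)}, "
        f"min={img.min().item():.3f}, max={img.max().item():.3f}, label={label}"
    )
```
Even a few lines like this (plus actually plotting some samples) will answer most of the questions above before you write a single line of model code.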
2) Know and understand existing work
Building GANs from scratch is extremely hard and can require many, many hours of debugging and careful hyperparameter tuning to get even remotely acceptable results. Thus, you should always search for existing work first.
Are there already papers using GANs for exactly what you want to do? Great, use them or build upon them!
Has no one done exactly what you want, but people have used GANs for related problems? Then try adjusting one of those methods to your task.
One important thing here is that you should understand very well whether a given model is suitable for your application or not. E.g., don’t use CycleGAN if you have paired data, don’t use StyleGAN if you have low-resolution images, …
How do you know whether a method is suitable? Read the paper! In particular, spend time understanding their method section and check their experiments. Try to understand the paper’s most important contribution and ask yourself whether that is important to you as well. Also, search for how other people used the method or how they expanded upon it in follow-up works.
3) If possible, use existing implementations and their hyperparameters
So you found a suitable existing method that you can use or build upon? Nice. Next, check if their implementation is open-source. If it is, use it!
Sometimes it can be tempting to reimplement methods on your own, especially if you plan to make substantial modifications or improvements later. However, as mentioned before, building GANs is really hard, and if you have any bug in your implementation the result might be totally off. Thus, even if the existing code is pure spaghetti, it is usually still faster to just refactor it. The only exception to this rule is if the original code is written in a language or framework you absolutely despise. In that case, go ahead and port it (and please publish your code later).
Also, on a similar note, always start out with the hyperparameters of the original paper/implementation. This will save you a lot of trouble since GANs are notoriously unstable w.r.t. hyperparameter choice.
Now, what do you do if there is no open-source implementation or even no suitable existing work at all? First, check again. If there really isn’t, you’ll have to implement and potentially design a GAN yourself. As I said, it’ll be painful, but it will also feel immensely satisfying once you get it to work, so don’t be discouraged!
4) Start with simple models
When you build models yourself, always start simple. In particular, start with a small architecture. A good rule of thumb is that your initial models should have less than a million parameters.
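Checking that rule of thumb takes two lines. A minimal sketch, where the toy Sequential models only stand in for your own generator and discriminator:
```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Count the trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Toy stand-ins for your own generator and discriminator.
generator = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784))
discriminator = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 1))

print(f"G parameters: {count_parameters(generator):,}")
print(f"D parameters: {count_parameters(discriminator):,}")
```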
For GANs, there are additional non-obvious implications:
- Start with the simple standard GAN loss (a minimal sketch follows after this list). In particular, don’t get baited by the Wasserstein loss because it is supposedly more stable. Don’t get me wrong, WGAN is nice, but its main advantage is that training is more stable over time and generally keeps improving the longer you train (which is often not the case with the standard loss). However, if your GAN produces only garbage, the Wasserstein loss cannot fix that. In that case, it will only make everything worse because it is much harder to debug.
- Similarly, don’t use other stabilization tricks like spectral normalization, training schedules, etc. in the beginning. Again, if your GAN doesn’t work there is likely a bug somewhere, and adding those techniques will not magically make it work (however, similar to WGAN, they are nice to add later).
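Here is that minimal sketch of the standard, non-saturating GAN loss in PyTorch; the generator, the discriminator, and the batch of real images are assumed to already exist:
```python
import torch
import torch.nn.functional as F

def gan_losses(generator, discriminator, real_images, latent_dim=100):
    """Standard (non-saturating) GAN losses for one batch."""
    z = torch.randn(real_images.size(0), latent_dim, device=real_images.device)
    fake_images = generator(z)

    # Discriminator: push real outputs towards 1 and fake outputs towards 0
    # (fakes are detached here, see tip 6).
    d_real = discriminator(real_images)
    d_fake = discriminator(fake_images.detach())
    d_loss = (
        F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
        + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    )

    # Generator: non-saturating loss, i.e. make the discriminator output 1 for fakes.
    g_fake = discriminator(fake_images)
    g_loss = F.binary_cross_entropy_with_logits(g_fake, torch.ones_like(g_fake))
    return d_loss, g_loss
```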
5) Start with simple data
In the beginning, keep not only the model as simple as possible but also the data. In particular, this means:
- Always start by overfitting a single sample (a minimal sketch of this check follows after the list). If that doesn’t work, you definitely have a major bug in your model, most likely in the loss function definition.
- If you can overfit a single sample, try a minibatch. Now your model can’t just memorize a fixed output anymore and needs to start paying attention to the input. If that doesn’t work you probably have a bug in data loading, so visualize your inputs and check if they make sense. If that’s not it, it can also be caused by vanishing/exploding gradients sometimes.
- Once you can overfit a minibatch, your model is most likely OK. (With other ML models you would want to overfit the whole training set before trying to generalize, but with GANs overfitting is usually not an issue, so you can skip this step.)
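And here is the single-sample overfitting check as a minimal sketch; model, criterion, and dataloader are placeholders for your own setup:
```python
import torch

# One fixed sample, fed in every step.
fixed_input, fixed_target = next(iter(dataloader))
fixed_input, fixed_target = fixed_input[:1], fixed_target[:1]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(fixed_input), fixed_target)
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        # The loss should drop to (almost) zero; if it doesn't, suspect the loss definition.
        print(f"step {step}: loss={loss.item():.6f}")
```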
6) Do not forget to freeze weights
One of the most annoying bugs in GANs is forgetting to freeze weights. As you know, the discriminator and the generator should be trained in an alternating fashion, which means you have to freeze one model while training the other (in PyTorch: detach() outputs or set the parameters to requires_grad=False).
This is easily overlooked and extremely hard to track down. Sometimes the GAN can even still produce acceptable outputs if freezing was only omitted for one of the two models, but training will be much slower and the final results will suffer.
Thus, always check twice whether you freeze/unfreeze your generator/discriminator at the correct times.
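To make the pattern concrete, a minimal sketch of one alternating update; the models, optimizers, loss helpers (d_loss_fn, g_loss_fn), the real batch, and the noise z are all placeholders for your own setup:
```python
import torch

def set_requires_grad(model: torch.nn.Module, flag: bool) -> None:
    """Freeze or unfreeze all parameters of a model."""
    for p in model.parameters():
        p.requires_grad = flag

fake_images = generator(z)

# Discriminator step: detach the fakes so no gradients flow into the generator.
set_requires_grad(discriminator, True)
opt_d.zero_grad()
d_loss = d_loss_fn(discriminator(real_images), discriminator(fake_images.detach()))
d_loss.backward()
opt_d.step()

# Generator step: freeze the discriminator so only generator weights are updated;
# gradients still flow *through* the frozen discriminator into the generator.
set_requires_grad(discriminator, False)
opt_g.zero_grad()
g_loss = g_loss_fn(discriminator(fake_images))
g_loss.backward()
opt_g.step()
```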
7) Prevent large weight updates in the discriminator
Another common problem is a discriminator learning rate that is too high. As mentioned before, GANs are very sensitive to hyperparameters, and the discriminator learning rate is the most important one. If it is too high, training collapses immediately (mode collapse) and nothing is learned at all.
My recommendation: start with a learning rate around 1e-5 and, if training is too slow, start increasing it from there.
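In code, this just means giving the discriminator its own optimizer with a conservative learning rate. A minimal sketch (the models are assumed to exist; the exact values, including the common GAN betas, are only starting points):
```python
import torch

# Separate optimizers so the discriminator learning rate can be tuned on its own.
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-5, betas=(0.5, 0.999))
```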
8) Make sure all components work individually
Even if you start simple and take care of the obvious issues, your GAN will likely still have some bugs in the beginning. Because a GAN is a complex system of components, problems can come from many places, so the best way to debug it is to break things up into smaller components and test those separately (a kind of unit testing, if you will).
Some checks I usually do:
- Can the generator be used as an autoencoder with L1 reconstruction loss?
- Can the discriminator be trained on a classification task and achieve satisfying accuracy?
- For encoder-decoder architectures, can the encoder be used for classification, too?
- For multi-GAN architectures (like CycleGAN), can each GAN learn unconditional distributions on its own?
- For complex tasks (like video generation), can your method do simpler tasks (like image generation)?
By breaking down components like this and training them on all kinds of other ML tasks, you can narrow down the problem and find out precisely where you need to look.
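As an example of the first check: for an encoder-decoder style generator, you can train it as a plain autoencoder with an L1 loss and no adversarial component at all. A minimal sketch, where generator and dataloader are placeholders for your own setup:
```python
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

for images, _ in dataloader:
    optimizer.zero_grad()
    reconstruction = generator(images)
    loss = F.l1_loss(reconstruction, images)
    loss.backward()
    optimizer.step()
    # If the reconstructions never get sharp, debug the generator architecture
    # before blaming the adversarial training.
```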
9) Log everything
Since GANs are so unstable, reproducibility is key. If even just one detail is slightly off, results can be vastly different, so you need to meticulously log everything.
You can use any tool you want for this. Recently, many people have been using Weights & Biases; it’s probably a good tool to learn. Personally, I always use plain old TensorBoard (a minimal sketch follows after this list) and log:
- hyperparameters (TensorBoard has special functionality for this, use it!)
- git commit (very important; allows you to revert to the exact code used)
- loss curves (as detailed as possible; if you have composite losses, log everything separately too, it’s super useful for finding potential issues)
- model outputs (images / videos / …)
- evaluation metrics (see next tip)
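Here is that minimal sketch with torch.utils.tensorboard; the hyperparameter values, tags, and the git-commit lookup via subprocess are just one way to do it, and d_loss, g_loss, fake_images, and step are placeholders from your training loop:
```python
import subprocess
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/my_gan_experiment")

# Hyperparameters (shown in TensorBoard's HPARAMS tab) and the exact git commit.
hparams = {"lr_g": 2e-4, "lr_d": 1e-5, "batch_size": 64}
writer.add_hparams(hparams, {"metrics/placeholder": 0.0})  # add_hparams needs a metric dict
commit = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
writer.add_text("git_commit", commit)

# Inside the training loop: log every loss component separately, plus outputs.
writer.add_scalar("loss/discriminator", d_loss.item(), global_step=step)
writer.add_scalar("loss/generator", g_loss.item(), global_step=step)
writer.add_images("samples/fake", fake_images.clamp(0, 1), global_step=step)
```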
10) “Look at metrics, not images” - Timo Aila
Finally, I want to conclude with this quote from Timo Aila, one of the leading researchers behind StyleGAN:
Look at metrics, not images.
It’s what helped me tremendously during my thesis on video generation.
The issue with GANs here is that it’s very tempting to get overly focused on qualitative results because we humans are very good at visual perception and can often intuitively tell whether something looks real or not.
However, you can only look at a small sample of outputs and your rating will likely be biased towards the model that produced the one output you liked best. Thus, you make suboptimal decisions if you only look at qualitative results.
In the end, you will want to conduct proper quantitative comparisons anyway, so just define your metrics early on and use them for model selection (just like you would with any other ML model that is not a GAN).
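For image GANs, FID is the usual default metric. A minimal sketch using torchmetrics (my choice here, any FID implementation works), where real_images and fake_images are placeholders for (N, 3, H, W) float tensors in [0, 1]:
```python
from torchmetrics.image.fid import FrechetInceptionDistance

# normalize=True tells torchmetrics to expect float images in [0, 1].
fid = FrechetInceptionDistance(feature=2048, normalize=True)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {fid.compute().item():.2f}")
```
Log this alongside your losses (tip 9) and let it, not your gut feeling about a handful of samples, drive model selection.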