Universal Language Model Fine-tuning for Text Classification (ULMFiT)

Universal Language Model Fine-tuning for Text Classification (ULMFiT) is a Transfer Learning method that fine-tunes a pre-trained Language Model so that it can be used for Text Classification.




Figure 1: ULMFiT consists of three stages: a) The LM is trained on a general-domain corpus to capture general features of the language in different layers. b) The full LM is fine-tuned on target task data using discriminative fine-tuning (‘Discr’) and slanted triangular learning rates (STLR) to learn task-specific features. c) The classifier is fine-tuned on the target task using gradual unfreezing, ‘Discr’, and STLR to preserve low-level representations and adapt high-level ones (shaded: unfreezing stages; black: frozen).
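The three techniques named in the caption can be illustrated with a small, framework-free sketch. The formulas and default values used here (`cut_frac=0.1`, `ratio=32`, `lr_max=0.01`, and a per-layer decay factor of 2.6) follow the ULMFiT paper as far as we recall it, not this page, so treat them as assumptions. The function names `stlr`, `discriminative_lrs`, and `trainable_layers` are hypothetical helpers written for this illustration only.

```python
import math


def stlr(t, total_steps, lr_max=0.01, cut_frac=0.1, ratio=32):
    """Slanted triangular learning rate at training step t (0-indexed).

    The rate rises linearly for the first cut_frac of the steps,
    then decays linearly back toward lr_max / ratio.
    Assumes total_steps * cut_frac >= 1 so that the warm-up is non-empty.
    """
    cut = math.floor(total_steps * cut_frac)
    if t < cut:
        p = t / cut  # warm-up phase: p goes from 0 to 1
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))  # decay phase: 1 to 0
    return lr_max * (1 + p * (ratio - 1)) / ratio


def discriminative_lrs(lr_top, n_layers, decay=2.6):
    """Per-layer learning rates; index 0 is the lowest layer.

    The top layer gets lr_top and each layer below it gets the rate of
    the layer above divided by `decay` (assumed value: 2.6).
    """
    return [lr_top / decay ** (n_layers - 1 - i) for i in range(n_layers)]


def trainable_layers(epoch, n_layers):
    """Gradual unfreezing: indices of the layers trained in a given epoch.

    Epoch 0 trains only the top layer; each later epoch unfreezes one
    more layer below it until all layers are trainable.
    """
    n_unfrozen = min(epoch + 1, n_layers)
    return list(range(n_layers - n_unfrozen, n_layers))


if __name__ == "__main__":
    total = 1000
    for step in (0, 50, 100, 500, 999):
        print(step, round(stlr(step, total), 6))
    print(discriminative_lrs(0.01, n_layers=4))
    for epoch in range(5):
        print(epoch, trainable_layers(epoch, n_layers=4))
```

In a real training loop these values would be fed to an optimizer with one parameter group per layer. The sketch only computes the schedules, which keeps it independent of any particular deep learning library.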


(...) This idea has been tried before, but required millions of documents for adequate performance. We found that we could do a lot better by being smarter about how we fine-tune our language model. In particular, we found that if we carefully control how fast our model learns and update the pre-trained model so that it does not forget what it has previously learned, the model can adapt a lot better to a new dataset. One thing that we were particularly excited to find is that the model can learn well even from a limited number of examples. On one text classification dataset with two classes, we found that training our approach with only 100 labeled examples (and giving it access to about 50,000 unlabeled examples), we were able to achieve the same performance as training a model from scratch with 10,000 labeled examples.

Another important insight was that we could use any reasonably general and large language corpus to create a universal language model—something that we could fine-tune for any NLP target corpus. We decided to use Stephen Merity’s Wikitext 103 dataset, which contains a pre-processed large subset of English Wikipedia.

High-level ULMFiT approach (IMDb example)