How ChatGPT is fine-tuned using Reinforcement Learning
Thanh Long Phan
At the end of 2022, OpenAI released ChatGPT (a Transformer-based language model) to the public. Although it is based on models from the GPT-3.5 series, which build on the already widely discussed GPT-3, it launched an unprecedented boom in generative AI.
It generates human-like text and has a wide range of applications, including language translation, text generation and conversational systems such as chatbots. Feel free to also read our introduction to LLMs.
ChatGPT seems so powerful that many people consider it a substantial step towards artificial general intelligence.
A major reason for the recent successes of language models such as ChatGPT is their size (in terms of trainable parameters). But making language models bigger does not inherently make them better at following a user's intent: a bigger model can still produce toxic text or "hallucinate", i.e. generate plausible-sounding but false statements. To mitigate these issues and, more generally, to align models with user intentions, one option is to apply Reinforcement Learning.
In this blog post, we will present an overview of the training process of ChatGPT and have a closer look at the use of Reinforcement Learning in language modeling.
What is a language model?
A language model is a probability distribution over sequences of words. By analyzing a large text corpus, the model learns the statistical structure of language. Given a sequence of words of length $$m$$, a language model assigns a probability $$P(w_0,...,w_{m-1})$$ to the whole sequence. To generate a text answering an input prompt, the model generates the next word token according to its probability given the prompt and the previous words:
$$P(w_i \mid x, w_0, \dots, w_{i-1}),$$
where $$w_i$$ is the token at position $$i$$ in the sequence and $$x$$ is the input prompt. In other words, the model assigns to each word in its vocabulary a probability that it comes next.
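To make this concrete, here is a minimal sketch with a hypothetical toy model. `next_token_probs` is a hard-coded stand-in for a real language model and is purely an assumption for illustration. Multiplying the conditional probabilities $$P(w_i \mid x, w_0, \dots, w_{i-1})$$ step by step (the chain rule) gives the probability of the whole generated sequence. The generation loop here is greedy: it always picks the most probable next token.

```python
import math

# Hypothetical toy "language model": returns P(next token | x, previous tokens)
# as a dict. A real model computes this with a neural network; this one is
# hard-coded and ignores the prompt x, purely for illustration.
def next_token_probs(prompt, prefix):
    if not prefix:
        return {"the": 0.6, "a": 0.4}
    if prefix[-1] in ("the", "a"):
        return {"cat": 0.7, "dog": 0.3}
    return {"<eos>": 1.0}

def sequence_log_prob(prompt, tokens):
    """log P(w_0, ..., w_{m-1} | x) = sum_i log P(w_i | x, w_0, ..., w_{i-1})."""
    total = 0.0
    for i, token in enumerate(tokens):
        p = next_token_probs(prompt, tokens[:i]).get(token, 0.0)
        if p == 0.0:
            return float("-inf")  # the model rules this sequence out
        total += math.log(p)
    return total

def greedy_generate(prompt, max_len=10):
    """Generate token by token, always taking the most probable next token."""
    tokens = []
    for _ in range(max_len):
        probs = next_token_probs(prompt, tokens)
        token = max(probs, key=probs.get)
        tokens.append(token)
        if token == "<eos>":  # end-of-sequence marker
            break
    return tokens

prompt = "Write a short phrase."
tokens = greedy_generate(prompt)
print(tokens)  # ['the', 'cat', '<eos>']
print(math.exp(sequence_log_prob(prompt, tokens)))  # approximately 0.6 * 0.7 * 1.0 = 0.42
```

Working in log-probabilities avoids numerical underflow when many small probabilities are multiplied over long sequences.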
How does a language model generate different outputs from the same input?
Several strategies have been developed for this. A common one is top-k sampling. After the model predicts the probability of every candidate next token, it keeps only the k tokens with the highest probabilities, renormalizes their probabilities and samples the next token from them. Because the next token is drawn at random, the model can generate a different text each time.
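A minimal sketch of top-k sampling follows, using only Python's standard `random` module. The next-word distribution is made up for illustration. Renormalization happens implicitly, because `random.choices` treats the weights as relative.

```python
import random

def top_k_sample(probs, k, rng=random):
    """Sample one token from the k most probable entries of `probs`.

    `probs` maps token -> probability. The kept probabilities are renormalized
    implicitly, because random.choices accepts relative (unnormalized) weights.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    tokens = [token for token, _ in top]
    weights = [p for _, p in top]
    return rng.choices(tokens, weights=weights, k=1)[0]

# Hypothetical next-word distribution predicted by a model.
next_word_probs = {"cat": 0.5, "dog": 0.3, "car": 0.15, "sky": 0.05}
rng = random.Random(0)  # seeded so that the run is reproducible
samples = [top_k_sample(next_word_probs, k=2, rng=rng) for _ in range(5)]
print(samples)  # only "cat" or "dog" can appear
```

With `k=1` this reduces to greedy decoding. Larger values of `k` allow more diverse, but also less likely, continuations.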
An overview of Reinforcement Learning
Unlike supervised learning, which learns from labeled examples, reinforcement learning is based on the interaction of an agent with an environment. The objective of the agent is to maximize a single scalar, called the reward, while following a policy $$\pi$$.
As in other machine learning setups, we parametrize the policy with a set of parameters $$\theta$$, for example the weights and biases of a neural network, and write it as $$\pi_\theta$$. Mathematically, the goal of reinforcement learning is to maximize the expected reward obtained by following the parametrized policy $$\pi_\theta$$:
$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[r(\tau)\right],$$
where $$\tau$$ is a sequence of (state, action) pairs of an agent in the environment and $$r(\tau)$$ is the total reward of the trajectory $$\tau$$. As in any machine learning problem, if we can find the parameters $$\hat{\theta}$$ such that
$$\hat{\theta} = \arg\max_\theta J(\theta),$$
then we have a model that solves the task. A standard approach to this maximization problem is gradient ascent. It works like gradient descent, but it moves in the direction of the gradient, so we update the parameters by
$$\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta),$$
where $$\alpha$$ is a suitable learning rate.
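As a minimal sketch of this update rule, we assume the simplest possible setting, which is our own toy assumption and not from the original. A trajectory is a single action out of three, each action has a fixed and known reward, and the policy $$\pi_\theta$$ is a softmax over $$\theta$$. Then $$J(\theta) = \sum_a \pi_\theta(a)\, r(a)$$, and the gradient can be computed exactly as $$\partial J / \partial \theta_a = \pi_\theta(a)\,(r(a) - J(\theta))$$. In realistic RL, including RLHF, this gradient has to be estimated from sampled trajectories instead.

```python
import math

def softmax(theta):
    m = max(theta)  # subtract the max for numerical stability
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]

def expected_reward(theta, rewards):
    """J(theta) = sum_a pi_theta(a) * r(a) for one-step trajectories."""
    return sum(p * r for p, r in zip(softmax(theta), rewards))

def grad_expected_reward(theta, rewards):
    """Exact gradient for a softmax policy: dJ/dtheta_a = pi(a) * (r(a) - J)."""
    probs = softmax(theta)
    j = sum(p * r for p, r in zip(probs, rewards))
    return [p * (r - j) for p, r in zip(probs, rewards)]

rewards = [1.0, 0.0, 2.0]  # hypothetical reward of each action
theta = [0.0, 0.0, 0.0]    # uniform policy at the start
alpha = 0.5                # learning rate

history = [expected_reward(theta, rewards)]
for _ in range(200):
    grad = grad_expected_reward(theta, rewards)
    theta = [t + alpha * g for t, g in zip(theta, grad)]  # gradient ASCENT step
    history.append(expected_reward(theta, rewards))

print(round(history[0], 3), round(history[-1], 3))
print([round(p, 3) for p in softmax(theta)])
```

The policy moves its probability mass towards the action with the highest reward, so $$J(\theta)$$ approaches the maximum reward of 2.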
One difficulty in reinforcement learning is designing the reward function. It is often unclear how much reward the agent should receive after each action, so designing a reward function is frequently a trial-and-error engineering process.
Why do we use reinforcement learning in language modeling?
With a language model, we can generate impressive text from human input prompts. But how do we define what makes an answer of the language model 'good'? Moreover, we do not just want our models to imitate human-written text. We want them to produce high-quality answers, for example text that is less toxic.
Writing a loss function that captures these attributes seems intractable. Instead, we can use human feedback on the generated text as a measure of performance and use it to optimize the model. This idea is called Reinforcement Learning from Human Feedback (RLHF).
The training process of ChatGPT
We introduce the training process of a language model using reinforcement learning in three steps:
Step 1: Collect demonstration data and train a supervised policy.
Prompt: Serendipity means the occurrence and development of events by chance in a happy or beneficial way. Use the word in a sentence.
Labeler: Running into Margaret and being introduced to Tom was a fortunate stroke of serendipity.
In the first step, we need to collect data and train a supervised policy.
Human trainers provide conversations in which they play both sides: the ideal AI assistant and the user. They also have access to responses suggested by the model to help them compose their answers. A pretrained transformer-based model is then fine-tuned on this new dialogue dataset combined with the InstructGPT dataset, which was transformed into a dialogue format.
The training tasks are from two sources:
a dataset of prompts written by the labelers,
a dataset of prompts submitted to early versions of the models through the API.
These prompts are very diverse and include generation, question answering, dialogue, summarization, extraction and other natural language tasks. For each prompt, the task is most often specified directly through a natural language instruction. It can also be specified indirectly, either through few-shot examples or through implicit continuation, i.e. providing the beginning of a text for the model to continue. In each case, the labelers are asked to do their best to infer the intent of the user who wrote the prompt and to skip inputs where the task is very unclear. The labelers also take implicit intentions into account, such as the truthfulness of the response and the avoidance of potentially harmful outputs.
The labelers were chosen carefully: they were sensitive to the preferences of different demographic groups and good at identifying potentially harmful outputs.
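Step 1 is ordinary supervised learning. The pretrained model is trained to maximize the likelihood of the demonstrations, i.e. to minimize the average negative log-likelihood (cross-entropy) of each demonstration token given the prompt and the previous tokens. The sketch below computes this loss from per-token probabilities that a model is assumed to have already produced. Masking out the prompt tokens, so that only the labeler's response is scored, is a common choice that we assume here; it is not a detail stated above.

```python
import math

def sft_loss(token_probs, is_response):
    """Mean negative log-likelihood over the response tokens.

    token_probs[i]: probability the model assigned to the actual i-th token of
                    the demonstration, given everything before it.
    is_response[i]: True if token i belongs to the labeler's answer
                    (prompt tokens are masked out of the loss).
    """
    losses = [-math.log(p) for p, keep in zip(token_probs, is_response) if keep]
    if not losses:
        raise ValueError("no response tokens to score")
    return sum(losses) / len(losses)

# Hypothetical probabilities for a 5-token example: 2 prompt + 3 response tokens.
probs = [0.2, 0.1, 0.9, 0.5, 0.8]
mask = [False, False, True, True, True]
print(sft_loss(probs, mask))
```

The loss is zero only if the model assigns probability 1 to every demonstration token. Lower values mean that the model imitates the demonstrations more closely.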
Step 2: Collect comparison data and train a reward model.
As we noted above, the reward function is hard to design. The underlying goal of the second step is to get a model that takes in a pair (prompt, text) and returns a scalar reward which should numerically represent the human preference. This model is an approximation of the actual reward function.
The fine-tuned model from the first step generates $$k$$ text samples for a given input prompt, and a human labeler ranks these samples from best to worst. One might think that labelers should directly assign a scalar score to each piece of text. However, such scores are hard to make consistent, because each labeler can have different preferences. Ranking $$k$$ responses yields $$\binom{k}{2}$$ pairs that can be compared. The comparisons within each labeling task are highly correlated, so simply shuffling all comparisons into one dataset would cause the reward model to overfit after a single pass over the data. Instead, all $$\binom{k}{2}$$ comparisons from each prompt are used together as a single batch element. This is much more computationally efficient, because it requires only one forward pass of the reward model per completion instead of one per comparison pair. And because the model no longer overfits, it achieves better validation accuracy and log loss.
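The following sketch shows how one ranking of $$k$$ completions turns into $$\binom{k}{2}$$ comparison pairs. It also shows why grouping the pairs per prompt is cheap: each completion is scored once, and those $$k$$ scores are reused for all pairs. The completion names and `score` are hypothetical; `score` stands in for a reward-model forward pass with made-up values.

```python
from itertools import combinations
from math import comb

def ranking_to_pairs(ranked):
    """Turn completions ranked best-to-worst into (preferred, rejected) pairs."""
    # combinations() preserves input order, so the first element of every
    # pair is always ranked higher than the second.
    return list(combinations(ranked, 2))

ranked = ["A", "B", "C", "D"]  # labeler's ranking for one prompt, best first
pairs = ranking_to_pairs(ranked)
print(len(pairs), comb(len(ranked), 2))  # 6 6
print(pairs[:3])  # [('A', 'B'), ('A', 'C'), ('A', 'D')]

forward_passes = []  # records every (hypothetical) reward-model call

def score(completion):
    """Hypothetical reward-model forward pass with made-up scores."""
    forward_passes.append(completion)
    return {"A": 2.0, "B": 1.0, "C": 0.5, "D": -1.0}[completion]

# One forward pass per completion (k calls), reused for all k*(k-1)/2 pairs.
rewards = {c: score(c) for c in ranked}
margins = [rewards[w] - rewards[l] for w, l in pairs]
print(len(forward_passes), len(margins))  # 4 6
```

For larger $$k$$ the saving grows quickly. For example, 9 completions yield 36 pairs but need only 9 forward passes.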
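We define the loss function for the reward model:
$$\text{loss}(\theta) = -\frac{1}{\binom{k}{2}}\, \mathbb{E}_{(x, y_w, y_l) \sim D}\left[\log\left(\sigma\left(r_\theta(x, y_w) - r_\theta(x, y_l)\right)\right)\right],$$
where $$r_\theta(x,y)$$ is the scalar output of the reward model for prompt $$x$$ and completion $$y$$, $$y_w$$ is the preferred completion out of the pair $$y_w$$ and $$y_l$$, and $$D$$ is the dataset of human comparisons. $$\sigma$$ denotes the sigmoid function, and the factor $$1/\binom{k}{2}$$ averages over the comparisons from one prompt. Hence, the more the reward of $$y_w$$ exceeds the reward of $$y_l$$, the smaller the loss.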
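Here is a minimal sketch of this loss for the comparisons from a single prompt. It uses plain Python floats instead of a deep learning framework, which is our simplification. The function averages $$-\log\sigma(r_\theta(x, y_w) - r_\theta(x, y_l))$$ over all pairs, and `log_sigmoid` uses a numerically stable formulation. The reward values are hypothetical.

```python
import math
from itertools import combinations

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def reward_model_loss(rewards_best_to_worst):
    """Average pairwise loss for one prompt.

    rewards_best_to_worst: scalar rewards r_theta(x, y) of the k completions,
    ordered by the labeler's ranking (best first), with k >= 2.
    """
    pairs = list(combinations(rewards_best_to_worst, 2))  # (r_w, r_l) pairs
    if not pairs:
        raise ValueError("need at least two ranked completions")
    return -sum(log_sigmoid(r_w - r_l) for r_w, r_l in pairs) / len(pairs)

print(reward_model_loss([0.0, 0.0]))        # log(2), about 0.693: no preference expressed
print(reward_model_loss([3.0, 1.0, -2.0]))  # small: rewards agree with the ranking
print(reward_model_loss([-2.0, 1.0, 3.0]))  # large: rewards contradict the ranking
```

Because only reward differences enter the loss, the reward model learns a relative scale. Adding a constant to all rewards does not change the loss.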
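To obtain the reward model, we take the supervised fine-tuned model from the first step and replace its final unembedding layer with a layer that outputs a single number. The resulting model takes in a prompt and a response, outputs a scalar reward, and is trained with the loss above.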
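To illustrate the architectural change, here is a toy sketch in which the transformer is reduced to a given final hidden state, a hypothetical vector. The language model's unembedding maps that hidden state to one logit per vocabulary token, while the reward model projects it to a single number. All weights are made up. Reading the reward from the hidden state of the last token is an assumption for illustration.

```python
def matvec(matrix, vector):
    """Multiply a (rows x d) matrix by a length-d vector."""
    return [sum(w * h for w, h in zip(row, vector)) for row in matrix]

# Hypothetical final hidden state (d = 3) of the last token of "prompt + response".
hidden = [0.5, -1.0, 2.0]

# Language-model head: unembedding matrix with one row per vocabulary token (V = 4).
unembedding = [
    [0.1, 0.0, 0.2],
    [0.0, 0.3, 0.1],
    [0.2, 0.2, 0.2],
    [-0.1, 0.1, 0.0],
]
logits = matvec(unembedding, hidden)  # V scores for the next token

# Reward-model head: the unembedding is replaced by a single row, so the
# output is one scalar reward instead of scores over the vocabulary.
reward_head = [[0.4, -0.2, 0.1]]
reward = matvec(reward_head, hidden)[0]

print(len(logits), reward)
```

Everything below the head is shared, so the reward model starts from the language understanding that the fine-tuned model has already learned.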
Step 3: Optimize a policy against the reward model using a PPO Reinforcement Learning algorithm.
Then, using the reward model $$r_\theta$$ from the second step, we define the objective for the third step:
$$\text{objective}(\phi) = \mathbb{E}_{(x,y) \sim D_{\pi_\phi^{RL}}}\left[r_\theta(x,y) - \beta \log\left(\frac{\pi_\phi^{RL}(y \mid x)}{\pi^{SFT}(y \mid x)}\right)\right] + \gamma\, \mathbb{E}_{x \sim D_{pretrain}}\left[\log\left(\pi_\phi^{RL}(x)\right)\right],$$
where $$\pi_\phi^{RL}$$ is the learned RL policy with parameters $$\phi$$ and $$\pi^{SFT}$$ is the supervised fine-tuned model from the first step. The Kullback-Leibler (KL) reward coefficient $$\beta$$ controls the strength of the KL penalty. We can think of the term with $$\beta$$ as a regularizer: it keeps $$\pi_\phi^{RL}$$ close to $$\pi^{SFT}$$, which mitigates over-optimization of the reward model. $$\gamma$$ is the pretraining loss coefficient, which controls the strength of the pretraining gradients. $$D_{pretrain}$$ denotes the pretraining distribution, i.e. the data the original GPT-3 model was trained on. $$D_{\pi_\phi^{RL}}$$ denotes the distribution of prompts $$x$$ from the PPO dataset, which does not contain any human labels, together with completions $$y$$ sampled from the current policy. We can think of $$D_{\pi_\phi^{RL}}$$ as the environment in which we want our agent to explore freely and learn.
The goal for this step is to find
$$\hat{\phi} = \arg\max_\phi \text{objective}(\phi).$$
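Below is a sketch of a Monte Carlo estimate of this objective. It assumes that, for each sampled pair $$(x, y)$$, we already have the reward-model score and the log-probabilities of $$y$$ under the RL policy and under the frozen SFT model. It also assumes the RL policy's log-likelihoods of some pretraining texts are available. All numbers and the values of $$\beta$$ and $$\gamma$$ are hypothetical. For $$y$$ sampled from the RL policy, $$\log \pi_\phi^{RL}(y \mid x) - \log \pi^{SFT}(y \mid x)$$ is a single-sample estimate of the KL divergence between the two models.

```python
def rlhf_objective(samples, pretrain_log_likelihoods, beta, gamma):
    """Monte Carlo estimate of objective(phi).

    samples: one dict per (x, y) with y sampled from the RL policy:
        "reward":   r_theta(x, y) from the reward model
        "logp_rl":  log pi_phi^RL(y | x)
        "logp_sft": log pi^SFT(y | x)
    pretrain_log_likelihoods: log pi_phi^RL(x) for texts x from D_pretrain
    """
    rl_term = sum(
        s["reward"] - beta * (s["logp_rl"] - s["logp_sft"])  # KL-penalized reward
        for s in samples
    ) / len(samples)
    if pretrain_log_likelihoods:
        pretrain_term = sum(pretrain_log_likelihoods) / len(pretrain_log_likelihoods)
    else:
        pretrain_term = 0.0
    return rl_term + gamma * pretrain_term

samples = [
    {"reward": 1.5, "logp_rl": -10.0, "logp_sft": -12.0},  # policy drifted from SFT
    {"reward": 0.5, "logp_rl": -8.0, "logp_sft": -8.0},    # same as SFT
]
print(rlhf_objective(samples, [-50.0, -40.0], beta=0.02, gamma=0.0))   # about 0.98
print(rlhf_objective(samples, [-50.0, -40.0], beta=0.02, gamma=0.01))  # about 0.53
```

The first sample earns a high reward but is penalized for drifting away from the SFT model. Setting `gamma=0` drops the pretraining term entirely.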
The policy is then updated by gradient ascent over several iterations of this process. Steps 2 and 3 can also be iterated continuously: more comparison data is collected on the current best policy and used to train a new reward model and, in turn, a new policy.
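This post does not detail how PPO performs these updates. As a hedged illustration of its core idea, here is the clipped surrogate objective of PPO for a batch of actions; for a language model, an action is a generated token. The probability ratio between the new and the old policy is clipped to $$[1 - \epsilon, 1 + \epsilon]$$, so a single update cannot move the policy too far. The advantages, i.e. how much better an action was than expected, are hypothetical inputs here. Computing them is beyond this sketch.

```python
import math

def ppo_clip_objective(logp_new, logp_old, advantages, epsilon=0.2):
    """Mean clipped surrogate objective, to be maximized by gradient ascent.

    logp_new:   log-probabilities of the taken actions under the current policy
    logp_old:   log-probabilities under the policy that generated the samples
    advantages: how much better each action was than expected (given)
    """
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a | s) / pi_old(a | s)
        clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
        # Take the pessimistic (smaller) of the unclipped and clipped terms.
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)

# A positive-advantage action whose probability already grew from 0.4 to 0.6:
# the clip caps its contribution at (1 + epsilon) * advantage.
print(ppo_clip_objective([math.log(0.6)], [math.log(0.4)], [1.0]))  # 1.2, not 1.5
```

Taking the minimum means clipping only removes the incentive to move further in the direction that already improved the objective. Changes that make the objective worse are never hidden by the clip.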
Results
After fine-tuning GPT-3 using Proximal Policy Optimization (PPO), OpenAI called the resulting model InstructGPT. Labelers rate the outputs of InstructGPT much higher than the outputs of GPT-3. Even the outputs of the 1.3B-parameter InstructGPT model are preferred over those of the 175B-parameter GPT-3, although it has more than 100 times fewer parameters.
InstructGPT models also show improvements in truthfulness and small improvements in toxicity over GPT-3.
Fine-tuning language models with humans in the loop has been shown to produce more reliable language models.
For further reading on reliable language models, I recommend one of our blog posts about Ethics in Natural Language Processing.
Fine-tuning LLMs with the help of reinforcement learning is not trivial, which is why it may make sense to talk to an AI company such as dida if you are considering customizing your own models for your company.
At this point, we would also like to refer you to our contact options, within which we are happy to advise you on LLM topics.