It has been a long time since my last tech post!
I’ve been busy with my graduate studies, job hunting, and Ph.D. applications, so I haven’t had the time to write a proper technical post.
Since my last semester has ended and I have a bit of time before starting a new chapter, I want to write a simple post on how to serve a deep neural network as an API using the FastAPI[1] library.
Let’s go through it.
Implementation of the model prediction
First, we should choose which model to serve.
In this post, I’m going to use the fine-tuned GPT-2 model I introduced in the previous post, since it is easy to use, intuitive, and small enough to fit on my local GPU.
For ease of use, I uploaded the trained model to my Hugging Face Hub repository.
You can find the model at this link.
As with other pre-trained models on the Hub, we only need a few lines to load the tokenizer and the fine-tuned model using the .from_pretrained function provided by the Transformers library.
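For illustration, loading them looks roughly like this (a minimal sketch; the repository id below is a placeholder for the actual Hub repository linked above):
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Placeholder repository id; replace it with the fine-tuned model's actual Hub repo.
model_path = "your-username/gpt2-chitchat"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)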
Still, the text input should be preprocessed in the same format used for fine-tuning.
Let’s walk through each step of implementing the inference logic by refactoring the original code here[2].
from itertools import chain
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
class Inferencer():
def __init__(self, args):
self.chat_history = {}
# Setting the GPU.
if torch.cuda.is_available() and isinstance(args.gpu, int):
self.device = torch.device(f"cuda:{args.gpu}")
else:
self.device = torch.device("cpu")
# Setting the tokenizer.
self.tokenizer = GPT2Tokenizer.from_pretrained(args.model_path)
special_tokens = self.tokenizer.special_tokens_map
self.bos_token = special_tokens['bos_token']
self.eos_token = special_tokens['eos_token']
self.sp1_token = special_tokens['additional_special_tokens'][0]
self.sp2_token = special_tokens['additional_special_tokens'][1]
vocab = self.tokenizer.get_vocab()
self.vocab_size = len(vocab)
self.bos_id = vocab[self.bos_token]
self.eos_id = vocab[self.eos_token]
self.sp1_id = vocab[self.sp1_token]
self.sp2_id = vocab[self.sp2_token]
# Setting the model.
self.model = GPT2LMHeadModel.from_pretrained(args.model_path).to(self.device)
self.model.eval()
# Decoding parameters.
self.max_turns = args.max_turns
self.max_len = self.model.config.n_ctx
self.top_p = args.top_p
self.temperature = args.temperature
self.num_beams = args.num_beams
...
The constructor is pretty straightforward: it sets up the GPU, the model, the tokenizer, the special tokens used for fine-tuning, and the hyperparameters for decoding.
All of the attributes coming from args can be set freely via command-line arguments.
The details of the special tokens are elaborated in the previous post on fine-tuning.
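For reference, a minimal sketch of how such special tokens can be registered during fine-tuning looks like this (the token strings here are assumptions for illustration; the actual ones are defined in the previous post):
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Assumed token strings, just for illustration.
special_tokens = {
    'bos_token': "<bos>",
    'eos_token': "<eos>",
    'additional_special_tokens': ["<sp1>", "<sp2>"],
}
tokenizer.add_special_tokens(special_tokens)
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))  # Make room for the newly added tokens.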
Also, while I previously implemented nucleus sampling[3] from scratch, this time we are going to use the generate function provided by the Transformers library.
So I defined a few more parameters besides top_p, such as temperature and num_beams.
Also, I defined the self.chat_history dictionary to memorize the chat history of each user.
Since this inferencer will be shared, each user’s records should be kept entirely separate.
Of course, this design can be extended; for example, with a multi-threaded server, each thread could handle one user connection and hold its own inferencer object that other users cannot touch.
However, for simplicity, we are going to use just one shared inferencer in this post.
To add a new message from a user, we can implement the add_message function as below.
# Adding a new message.
def add_message(self, user_id, speaker_id, message):
if user_id not in self.chat_history:
self.chat_history[user_id] = []
self.chat_history[user_id].append((speaker_id, message))
Finally, I implemented the prediction function to process each input message and generate a new response from the model:
# A single prediction.
async def predict(self, user_id):
input_hists = []
for tup in self.chat_history[user_id]:
token_ids = [self.sp1_id if tup[0] == 1 else self.sp2_id] + self.tokenizer.encode(tup[1])
input_hists.append(token_ids)
# Adjusting the length.
if len(input_hists) >= self.max_turns:
num_exceeded = len(input_hists) - self.max_turns + 1
input_hists = input_hists[num_exceeded:]
# Setting the input ids and type ids.
input_ids = [self.bos_id] + list(chain.from_iterable(input_hists)) + [self.sp2_id]
start_sp_id = input_hists[0][0]
next_sp_id = self.sp1_id if start_sp_id == self.sp2_id else self.sp2_id
token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.sp2_id]
input_len = len(input_ids)
input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)
# Getting the output.
output_ids = self.model.generate(
input_ids=input_ids,
token_type_ids=token_type_ids,
pad_token_id=self.eos_id,
do_sample=True,
top_p=self.top_p,
max_length=self.max_len,
num_beams=self.num_beams,
temperature=self.temperature,
output_hidden_states=True,
output_scores=True,
return_dict_in_generate=True
).sequences
output_ids = output_ids[0].tolist()[input_len:]
res = self.tokenizer.decode(output_ids, skip_special_tokens=True)
self.chat_history[user_id].append((2, res))
return res
As mentioned before, this is a refactored version of the original repository.
The differences are that we use the model.generate function and that we encode the whole chat record at every prediction, rather than preprocessing the messages beforehand.
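If you want to sanity-check the inference logic on its own before wiring it to the API, a quick sketch looks like this (the argument values, especially model_path, are placeholders):
import argparse
import asyncio
# Placeholder arguments for a quick local test.
args = argparse.Namespace(
    gpu=0,
    model_path="your-username/gpt2-chitchat",
    max_turns=5,
    top_p=0.9,
    temperature=1.0,
    num_beams=1,
)
inferencer = Inferencer(args)
inferencer.add_message("test-user", 1, "Hello, how are you?")
print(asyncio.run(inferencer.predict("test-user")))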
Integration with FastAPI
This is the main part, where we implement the FastAPI endpoints for actual usage.
First, we define the data format using Pydantic[4] so that the request data is parsed into the correct types.
from pydantic import BaseModel
# Definition of a basic text input.
class TextInput(BaseModel):
user_id: str
message: str
Here, I defined user_id, a unique string for each user, and message, the actual text message the user typed.
Using this model, we can easily validate the data given in the body of a POST request.
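For example, here is a small sketch of the validation behavior (run after defining TextInput):
from pydantic import ValidationError
# A well-formed body parses into a TextInput object.
data = TextInput(user_id="test-user", message="Hello!")
print(data.user_id, data.message)
# A body with a missing field raises a validation error,
# which FastAPI turns into a 422 response automatically.
try:
    TextInput(user_id="test-user")
except ValidationError as e:
    print(e)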
Now, we are going to add the FastAPI object and a few endpoints to process HTTP requests.
from fastapi import FastAPI
app = FastAPI()
# Default page.
@app.get("/")
def index():
return {'message': "Welcome to the basic GPT2 chit chat API!"}
# Posting one user message.
@app.post("/infer")
async def infer(data: TextInput):
data = data.dict()
user_id = data['user_id']
message = data['message']
inferencer.add_message(user_id, 1, message)
response = await inferencer.predict(user_id)
return {'message': response}
# Running the server.
uvicorn.run(app, host='127.0.0.1', port=args.port)
First, I made a FastAPI object called app.
Using the @app decorator, we can register a function that processes an HTTP request and returns the corresponding result.
Here, @app.get("/") handles the default endpoint and just shows a welcome message.
@app.post("/infer") is the main function that processes one interaction between the user and the model.
When a POST request arrives at the /infer endpoint, this function parses user_id and message from the request body, adds the new user message to the chat history kept in the inferencer, and returns the response predicted by the model.
In addition, I defined infer as an asynchronous function so that each prediction can be awaited, since a single generation can take a while considering the data transfer between CPU memory and GPU memory.
While this hardly matters in a simple example like this, in a real web application it is important to consider asynchronous processing when a prediction takes long enough to affect other functionality.
One advantage of FastAPI is that we don’t have to set up the event loop ourselves with asyncio to run asynchronous endpoint functions; FastAPI, together with Uvicorn, handles that for us.
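One caveat: the model call inside predict is synchronous, so it still blocks the event loop while a response is being generated. If that ever became a problem, one option (a hedged sketch, not part of the original code, and assuming predict is rewritten as a regular non-async method) would be to push the blocking work onto a worker thread with asyncio.to_thread (Python 3.9+):
import asyncio
@app.post("/infer")
async def infer(data: TextInput):
    inferencer.add_message(data.user_id, 1, data.message)
    # Run the blocking generation in a worker thread so the event loop stays responsive.
    response = await asyncio.to_thread(inferencer.predict, data.user_id)
    return {'message': response}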
Ultimately, the full file serve_gpt2_fastapi.py looks like this:
from itertools import chain
from pydantic import BaseModel
from fastapi import FastAPI
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import uvicorn
import torch
import argparse
# Definition of a basic text input.
class TextInput(BaseModel):
user_id: str
message: str
# Definition of the main inferencer class.
class Inferencer():
def __init__(self, args):
self.chat_history = {}
# Setting the GPU.
if torch.cuda.is_available() and isinstance(args.gpu, int):
self.device = torch.device(f"cuda:{args.gpu}")
else:
self.device = torch.device("cpu")
# Setting the tokenizer.
self.tokenizer = GPT2Tokenizer.from_pretrained(args.model_path)
special_tokens = self.tokenizer.special_tokens_map
self.bos_token = special_tokens['bos_token']
self.eos_token = special_tokens['eos_token']
self.sp1_token = special_tokens['additional_special_tokens'][0]
self.sp2_token = special_tokens['additional_special_tokens'][1]
vocab = self.tokenizer.get_vocab()
self.vocab_size = len(vocab)
self.bos_id = vocab[self.bos_token]
self.eos_id = vocab[self.eos_token]
self.sp1_id = vocab[self.sp1_token]
self.sp2_id = vocab[self.sp2_token]
# Setting the model.
self.model = GPT2LMHeadModel.from_pretrained(args.model_path).to(self.device)
self.model.eval()
# Decoding parameters.
self.max_turns = args.max_turns
self.max_len = self.model.config.n_ctx
self.top_p = args.top_p
self.temperature = args.temperature
self.num_beams = args.num_beams
# Adding a new message.
def add_message(self, user_id, speaker_id, message):
if user_id not in self.chat_history:
self.chat_history[user_id] = []
self.chat_history[user_id].append((speaker_id, message))
# A single prediction.
async def predict(self, user_id):
input_hists = []
for tup in self.chat_history[user_id]:
token_ids = [self.sp1_id if tup[0] == 1 else self.sp2_id] + self.tokenizer.encode(tup[1])
input_hists.append(token_ids)
# Adjusting the length.
if len(input_hists) >= self.max_turns:
num_exceeded = len(input_hists) - self.max_turns + 1
input_hists = input_hists[num_exceeded:]
# Setting the input ids and type ids.
input_ids = [self.bos_id] + list(chain.from_iterable(input_hists)) + [self.sp2_id]
start_sp_id = input_hists[0][0]
next_sp_id = self.sp1_id if start_sp_id == self.sp2_id else self.sp2_id
token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.sp2_id]
input_len = len(input_ids)
input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)
# Getting the output.
output_ids = self.model.generate(
input_ids=input_ids,
token_type_ids=token_type_ids,
pad_token_id=self.eos_id,
do_sample=True,
top_p=self.top_p,
max_length=self.max_len,
num_beams=self.num_beams,
temperature=self.temperature,
output_hidden_states=True,
output_scores=True,
return_dict_in_generate=True
).sequences
output_ids = output_ids[0].tolist()[input_len:]
res = self.tokenizer.decode(output_ids, skip_special_tokens=True)
self.chat_history[user_id].append((2, res))
return res
app = FastAPI()
# Default page.
@app.get("/")
def index():
return {'message': "Welcome to the basic GPT2 chit chat API!"}
# Posting one user message.
@app.post("/infer")
async def infer(data: TextInput):
data = data.dict()
user_id = data['user_id']
message = data['message']
inferencer.add_message(user_id, 1, message)
response = await inferencer.predict(user_id)
return {'message': response}
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=8000, help="The port number.")
parser.add_argument('--model_path', type=str, required=True, help="The path to the model in HuggingFace Hub.")
parser.add_argument('--gpu', type=int, default=0, help="The index of GPU to use.")
parser.add_argument('--max_turns', type=int, default=5, help="The maximum number of dialogue histories to include.")
parser.add_argument('--top_p', type=float, default=1.0, help="The top p value for nucleus sampling.")
parser.add_argument('--temperature', type=float, default=1.0, help="The temperature value.")
parser.add_argument('--num_beams', type=int, default=1, help="The number of beams for beam search.")
args = parser.parse_args()
# Initializing the inferencer.
inferencer = Inferencer(args)
# Running the server.
uvicorn.run(app, host='127.0.0.1', port=args.port)
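To start the server, the script can be run with a command like python serve_gpt2_fastapi.py --model_path <the Hub repository of the fine-tuned model>, optionally overriding any of the other arguments above.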
How does it look?
Now, let’s see how it works by testing the implementation with Postman[5].
First, we can see the server running on localhost at port 8000:
When we access the endpoint /, we can see the welcome message as follows:
Now, I’m going to interact with the model.
Using the user ID “devjwsong”, I chatted with the model:
We can see that the model returns a corresponding response given a POST /infer request with user_id and message.
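If you prefer testing from Python instead of Postman, a minimal sketch with the requests library (assuming the server is running locally on port 8000, as above) looks like this:
import requests
# Send one user message to the /infer endpoint and print the model's response.
body = {'user_id': "devjwsong", 'message': "Hi, nice to meet you."}
res = requests.post("http://127.0.0.1:8000/infer", json=body)
print(res.status_code)        # 200 if everything went well.
print(res.json()['message'])  # The generated response from the model.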
Also, I printed the chat log kept in the inferencer for debugging:
In this post, we’ve looked at how to build a chat API from a fine-tuned GPT-2 model with FastAPI.
While this is super simple and far from a large-scale application, where a wide range of functionalities, batch processing, and inference speed have to be considered, I hope it gives you a basic understanding of serving an ML model as an API.
I’m planning to make the next post a review of the paper I’ve recently submitted.
Thank you.