Jaewoo Song

It has been a long time since I posted a tech post!

I’ve been so busy with my graduate studies, job hunting, and Ph.D. applications that I haven’t had enough time to write a technical post.

Since my last semester has ended and I have a bit of time before starting a new chapter, I want to write a simple post on how to serve a deep neural network as an API using the FastAPI[1] library.

Let’s go through it.



Implementation of the model prediction

First, we need to choose which model to serve.

In this post, I’m going to use the fine-tuned GPT-2 model that I introduced in the previous post, because it is easy to use, intuitive, and small enough to load onto my local GPU.

For ease of use, I uploaded the trained model to my Hugging Face Hub repository.

You can find the model at this link.

The screenshot of the repository page.


As with other pre-trained models on the Hub, loading the tokenizer and the fine-tuned model takes only a few lines with the from_pretrained method provided by the Transformers library.

However, the input text must still be preprocessed into the same format that was used for fine-tuning.

Let’s implement the inference logic step by step by refactoring the original code[2].

from itertools import chain
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

class Inferencer():
    def __init__(self, args):
        self.chat_history = {}

        # Setting the GPU.
        if torch.cuda.is_available() and isinstance(args.gpu, int):
            self.device = torch.device(f"cuda:{args.gpu}")
        else:
            self.device = torch.device("cpu")

        # Setting the tokenizer.
        self.tokenizer = GPT2Tokenizer.from_pretrained(args.model_path)
        special_tokens = self.tokenizer.special_tokens_map
        self.bos_token = special_tokens['bos_token']
        self.eos_token = special_tokens['eos_token']
        self.sp1_token = special_tokens['additional_special_tokens'][0]
        self.sp2_token = special_tokens['additional_special_tokens'][1]

        vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(vocab)
        self.bos_id = vocab[self.bos_token]
        self.eos_id = vocab[self.eos_token]
        self.sp1_id = vocab[self.sp1_token]
        self.sp2_id = vocab[self.sp2_token]

        # Setting the model.
        self.model = GPT2LMHeadModel.from_pretrained(args.model_path).to(self.device)
        self.model.eval()

        # Decoding parameters.
        self.max_turns = args.max_turns
        self.max_len = self.model.config.n_ctx
        self.top_p = args.top_p
        self.temperature = args.temperature
        self.num_beams = args.num_beams
        
    ...


The constructor is pretty straightforward, setting the GPU, model, tokenizer, special tokens we used for fine-tuning, and the hyperparameters for decoding.

All attributes of args can be set freely through command-line arguments.

The details of the special tokens are elaborated in the previous post on fine-tuning.

Also, while I implemented nucleus sampling[3] from scratch in the original code, this time we are going to use the generate function from the Transformers library.

So, in addition to top_p, I defined a few more decoding parameters such as temperature and num_beams.
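Since the original code implemented nucleus sampling[3] from scratch, here is a minimal, dependency-free sketch of what top_p and temperature do to a single decoding step. It only illustrates the technique on a plain list of logits; it is not the Transformers implementation, and the function name top_p_sample is hypothetical.

```python
import math
import random


def top_p_sample(logits, top_p=0.9, temperature=1.0, rng=None):
    """Sample one token index using temperature scaling plus nucleus (top-p) filtering.

    Hypothetical helper for illustration; `logits` is a plain list of floats.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()

    # Temperature scaling: values below 1.0 sharpen the distribution, above 1.0 flatten it.
    scaled = [x / temperature for x in logits]

    # Numerically stable softmax.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token indices sorted by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # Nucleus: the smallest prefix whose cumulative probability reaches top_p.
    nucleus, cumulative = [], 0.0
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # random.choices treats weights as relative, which renormalizes them within the nucleus.
    weights = [probs[i] for i in nucleus]
    return rng.choices(nucleus, weights=weights, k=1)[0]
```

For example, with logits [1.0, 5.0, 2.0] the second token already holds more than 90% of the probability mass, so top_p_sample([1.0, 5.0, 2.0], top_p=0.5) always returns 1. With top_p=1.0 every token stays in the nucleus and plain temperature sampling remains.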


I also defined the self.chat_history dictionary to store the chat history of each user.

Since this inferencer is shared by all users, the records of different users must be kept completely separate.

Of course, this design can be extended. For example, in a multi-threaded server, each thread could take one user connection and hold its own inferencer object that other users cannot access.

However, for simplicity, we are going to use just one shared inferencer in this post.

To add a new message from a user, we can use the add_message function below.

# Adding a new message.
def add_message(self, user_id, speaker_id, message):
    if user_id not in self.chat_history:
        self.chat_history[user_id] = []

    self.chat_history[user_id].append((speaker_id, message))
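The plain dictionary above works while every request runs on a single event loop. If requests are ever handled by several worker threads, the check-then-insert in add_message is not atomic, so a lock around the shared dictionary is a cheap safeguard. Below is a minimal sketch under that assumption; ChatHistoryStore is a hypothetical class, not part of the original code.

```python
import threading
from collections import defaultdict


class ChatHistoryStore:
    """Per-user chat histories guarded by a lock (hypothetical helper)."""

    def __init__(self):
        # user_id -> [(speaker_id, message), ...]
        self._histories = defaultdict(list)
        self._lock = threading.Lock()

    def add_message(self, user_id, speaker_id, message):
        # Only one thread at a time may create or extend a user's history.
        with self._lock:
            self._histories[user_id].append((speaker_id, message))

    def snapshot(self, user_id):
        # Return a copy so callers never mutate the shared list directly;
        # .get() avoids creating an empty entry for unknown users.
        with self._lock:
            return list(self._histories.get(user_id, []))
```

Because each user's messages live under a separate key, one user's snapshot never contains another user's turns.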


Finally, I implemented the prediction function, which builds the model input from the user’s chat history and generates a new response from the model:

# A single prediction.
async def predict(self, user_id):
    input_hists = []
    for tup in self.chat_history[user_id]:
        token_ids = [self.sp1_id if tup[0] == 1 else self.sp2_id] + self.tokenizer.encode(tup[1])
        input_hists.append(token_ids)

    # Adjusting the length.
    if len(input_hists) >= self.max_turns:
        num_exceeded = len(input_hists) - self.max_turns + 1
        input_hists = input_hists[num_exceeded:]

    # Setting the input ids and type ids.
    input_ids = [self.bos_id] + list(chain.from_iterable(input_hists)) + [self.sp2_id]
    start_sp_id = input_hists[0][0]
    next_sp_id = self.sp1_id if start_sp_id == self.sp2_id else self.sp2_id
    token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
    token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.sp2_id]
    input_len = len(input_ids)

    input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
    token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)

    # Getting the output.
    output_ids = self.model.generate(
        input_ids=input_ids, 
        token_type_ids=token_type_ids, 
        pad_token_id=self.eos_id,
        do_sample=True, 
        top_p=self.top_p, 
        max_length=self.max_len, 
        num_beams=self.num_beams, 
        temperature=self.temperature,
        output_hidden_states=True, 
        output_scores=True, 
        return_dict_in_generate=True
    ).sequences
    output_ids = output_ids[0].tolist()[input_len:]
    res = self.tokenizer.decode(output_ids, skip_special_tokens=True)
    self.chat_history[user_id].append((2, res))

    return res


As mentioned before, this is a refactored version of the original repository.

The differences are that we use the model.generate function, and that we re-encode every message in the user’s history on each prediction instead of preprocessing the inputs beforehand.
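To make the input layout built in predict concrete, here is the same construction reduced to plain Python with made-up token IDs. BOS, SP1, SP2, and the word IDs are hypothetical placeholders, not GPT-2's real vocabulary, and build_inputs is a hypothetical name. Like the original, the sketch assumes the two speakers strictly alternate.

```python
from itertools import chain

# Hypothetical special-token IDs, for illustration only; the real ones come from the tokenizer.
BOS, SP1, SP2 = 0, 1, 2


def build_inputs(histories, max_turns):
    """histories: list of token-ID lists, each starting with its speaker token."""
    # Keep at most max_turns - 1 past turns, leaving room for the upcoming reply.
    if len(histories) >= max_turns:
        histories = histories[len(histories) - max_turns + 1:]

    # <bos> + all turns + <sp2>, which prompts speaker 2 (the bot) to answer.
    input_ids = [BOS] + list(chain.from_iterable(histories)) + [SP2]

    # Every token in a turn gets that turn's speaker token as its token type ID.
    start_sp = histories[0][0]
    next_sp = SP1 if start_sp == SP2 else SP2
    per_turn = [[start_sp if h % 2 == 0 else next_sp] * len(hist) for h, hist in enumerate(histories)]
    token_type_ids = [start_sp] + list(chain.from_iterable(per_turn)) + [SP2]
    return input_ids, token_type_ids


# user: [10, 11], bot: [20], user: [30, 31, 32]
histories = [[SP1, 10, 11], [SP2, 20], [SP1, 30, 31, 32]]
input_ids, token_type_ids = build_inputs(histories, max_turns=5)
# input_ids      -> [0, 1, 10, 11, 2, 20, 1, 30, 31, 32, 2]
# token_type_ids -> [1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2]
```

With max_turns=3, the oldest turn is dropped and the sequence starts from the bot's turn, so the type IDs start with SP2 instead.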



Integration with FastAPI

This is the main part, where we implement the FastAPI endpoints that will actually be used.

First, we define the request data format using Pydantic[4] so that the incoming data is parsed into the correct types.

from pydantic import BaseModel

# Definition of a basic text input.
class TextInput(BaseModel):
    user_id: str
    message: str


Here, user_id is a string that uniquely identifies each user, and message is the actual text the user typed.

Using this model, we can easily validate the data given in the body of the POST request.
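Pydantic performs this validation for us, and FastAPI answers an invalid body with a 422 error response before infer is ever called. To show roughly what is being checked for TextInput, here is a stdlib-only sketch; it is a simplified stand-in rather than how Pydantic works internally, and TextInputSketch is a hypothetical name.

```python
import json
from dataclasses import dataclass


@dataclass
class TextInputSketch:
    """Stdlib stand-in for the Pydantic TextInput model (hypothetical, for illustration)."""
    user_id: str
    message: str

    @classmethod
    def from_json(cls, body):
        # json.JSONDecodeError is a subclass of ValueError, so malformed JSON also raises ValueError.
        data = json.loads(body)
        if not isinstance(data, dict):
            raise ValueError("request body must be a JSON object")
        for field in ("user_id", "message"):
            if field not in data:
                raise ValueError(f"missing field: {field}")
            if not isinstance(data[field], str):
                raise ValueError(f"field {field} must be a string")
        return cls(user_id=data["user_id"], message=data["message"])
```

For instance, TextInputSketch.from_json('{"user_id": "devjwsong", "message": "Hi!"}') succeeds, while a body without message raises ValueError.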


Now, we are going to add the FastAPI object and a few endpoints to process HTTP requests.

from fastapi import FastAPI
import uvicorn

app = FastAPI()

# Default page.
@app.get("/")
def index():
    return {'message': "Welcome to the basic GPT2 chit chat API!"}


# Posting one user message.
@app.post("/infer")
async def infer(data: TextInput):
    data = data.dict()

    user_id = data['user_id']
    message = data['message']

    inferencer.add_message(user_id, 1, message)
    response = await inferencer.predict(user_id)

    return {'message': response}


# Running the server (args and inferencer are created in the __main__ block of the full file).
uvicorn.run(app, host='127.0.0.1', port=args.port)


First, I made a FastAPI object called app.

Using route decorators such as @app.get and @app.post, we can register a function that processes an HTTP request and returns the corresponding result.

Here, @app.get("/") handles the root endpoint and simply returns a welcome message.

The function decorated with @app.post("/infer") is the main handler for one interaction between the user and the model.

When a POST request to the endpoint /infer arrives, this function reads user_id and message from the parsed request body, adds the new user message to that user’s chat history in the inferencer, and returns the response predicted by the model.
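The decorator pattern itself is ordinary Python: a decorator records the function under an HTTP method and a path, and the server looks the handler up when a matching request arrives. Here is a toy, stdlib-only sketch of that idea. ToyRouter and toy_app are hypothetical names, and FastAPI's real routing does much more (request validation, dependency injection, async handlers).

```python
class ToyRouter:
    """Toy illustration of decorator-based route registration (hypothetical, not FastAPI)."""

    def __init__(self):
        self.routes = {}  # (HTTP method, path) -> handler function

    def route(self, method, path):
        def decorator(func):
            self.routes[(method, path)] = func
            return func  # the decorated function itself is returned unchanged
        return decorator

    def get(self, path):
        return self.route("GET", path)

    def post(self, path):
        return self.route("POST", path)

    def dispatch(self, method, path, *args):
        # Look up the handler registered for this method and path.
        handler = self.routes.get((method, path))
        if handler is None:
            return {"detail": "Not Found"}
        return handler(*args)


toy_app = ToyRouter()


@toy_app.get("/")
def index():
    return {"message": "Welcome!"}


@toy_app.post("/echo")
def echo(body):
    return {"message": body["message"]}
```

Calling toy_app.dispatch("POST", "/echo", {"message": "hi"}) runs echo and returns {"message": "hi"}, while an unregistered method/path pair falls through to the "Not Found" result.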


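In addition, I defined infer as an asynchronous function that awaits each prediction, since one generation can take a while.

While this does not matter in a simple example like this, in an actual Web application it is important to consider asynchronous processing when a prediction takes long enough to affect other functionality.

Note, however, that async alone does not make the generation non-blocking: predict has no await point of its own, and model.generate is a blocking call, so the event loop cannot serve other requests until it returns. To keep the server responsive, the blocking work has to be moved off the event loop, for example by declaring the endpoint with a plain def (FastAPI then runs it in a thread pool) or by wrapping the call in asyncio.to_thread.

One advantage of FastAPI is that we don’t need to set up an event loop ourselves to use asynchronous endpoints, since FastAPI and Uvicorn handle it for us.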
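A blocking call inside an async function holds up the event loop until it returns. The following stdlib-only sketch contrasts calling a blocking function directly from a coroutine with offloading it through asyncio.to_thread (Python 3.9+). blocking_generate is a hypothetical stand-in for model.generate, with time.sleep simulating the slow work; the timings are illustrative, not measurements of the real model.

```python
import asyncio
import time


def blocking_generate(prompt, delay=0.5):
    # Hypothetical stand-in for a blocking call such as model.generate.
    time.sleep(delay)
    return f"reply to {prompt}"


async def infer_blocking(prompt):
    # Calls the blocking function directly: nothing else can run on the loop meanwhile.
    return blocking_generate(prompt)


async def infer_offloaded(prompt):
    # Runs the blocking function in a worker thread, leaving the event loop free.
    return await asyncio.to_thread(blocking_generate, prompt)


async def time_two_requests(handler):
    # Simulates two requests arriving at the same time and measures the total wall time.
    start = time.perf_counter()
    results = await asyncio.gather(handler("a"), handler("b"))
    return results, time.perf_counter() - start
```

Running asyncio.run(time_two_requests(infer_blocking)) takes about twice the delay because the two calls execute one after the other, while infer_offloaded overlaps them. With a real GPU model, concurrent generate calls may still contend for the device, but the event loop stays free to answer other requests such as GET /. Inside FastAPI, fastapi.concurrency.run_in_threadpool serves the same purpose.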


Putting it all together, the full file serve_gpt2_fastapi.py looks like this:

from itertools import chain
from pydantic import BaseModel
from fastapi import FastAPI
from transformers import GPT2Tokenizer, GPT2LMHeadModel

import uvicorn
import torch
import argparse


# Definition of a basic text input.
class TextInput(BaseModel):
    user_id: str
    message: str


# Definition of the main inferencer class.
class Inferencer():
    def __init__(self, args):
        self.chat_history = {}

        # Setting the GPU.
        if torch.cuda.is_available() and isinstance(args.gpu, int):
            self.device = torch.device(f"cuda:{args.gpu}")
        else:
            self.device = torch.device("cpu")

        # Setting the tokenizer.
        self.tokenizer = GPT2Tokenizer.from_pretrained(args.model_path)
        special_tokens = self.tokenizer.special_tokens_map
        self.bos_token = special_tokens['bos_token']
        self.eos_token = special_tokens['eos_token']
        self.sp1_token = special_tokens['additional_special_tokens'][0]
        self.sp2_token = special_tokens['additional_special_tokens'][1]

        vocab = self.tokenizer.get_vocab()
        self.vocab_size = len(vocab)
        self.bos_id = vocab[self.bos_token]
        self.eos_id = vocab[self.eos_token]
        self.sp1_id = vocab[self.sp1_token]
        self.sp2_id = vocab[self.sp2_token]

        # Setting the model.
        self.model = GPT2LMHeadModel.from_pretrained(args.model_path).to(self.device)
        self.model.eval()

        # Decoding parameters.
        self.max_turns = args.max_turns
        self.max_len = self.model.config.n_ctx
        self.top_p = args.top_p
        self.temperature = args.temperature
        self.num_beams = args.num_beams

    # Adding a new message.
    def add_message(self, user_id, speaker_id, message):
        if user_id not in self.chat_history:
            self.chat_history[user_id] = []

        self.chat_history[user_id].append((speaker_id, message))

    # A single prediction.
    async def predict(self, user_id):
        input_hists = []
        for tup in self.chat_history[user_id]:
            token_ids = [self.sp1_id if tup[0] == 1 else self.sp2_id] + self.tokenizer.encode(tup[1])
            input_hists.append(token_ids)

        # Adjusting the length.
        if len(input_hists) >= self.max_turns:
            num_exceeded = len(input_hists) - self.max_turns + 1
            input_hists = input_hists[num_exceeded:]

        # Setting the input ids and type ids.
        input_ids = [self.bos_id] + list(chain.from_iterable(input_hists)) + [self.sp2_id]
        start_sp_id = input_hists[0][0]
        next_sp_id = self.sp1_id if start_sp_id == self.sp2_id else self.sp2_id
        token_type_ids = [[start_sp_id] * len(hist) if h % 2 == 0 else [next_sp_id] * len(hist) for h, hist in enumerate(input_hists)]
        token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [self.sp2_id]
        input_len = len(input_ids)

        input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.device)
        token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.device)

        # Getting the output.
        output_ids = self.model.generate(
            input_ids=input_ids, 
            token_type_ids=token_type_ids, 
            pad_token_id=self.eos_id,
            do_sample=True, 
            top_p=self.top_p, 
            max_length=self.max_len, 
            num_beams=self.num_beams, 
            temperature=self.temperature,
            output_hidden_states=True, 
            output_scores=True, 
            return_dict_in_generate=True
        ).sequences
        output_ids = output_ids[0].tolist()[input_len:]
        res = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        self.chat_history[user_id].append((2, res))

        return res
        

app = FastAPI()


# Default page.
@app.get("/")
def index():
    return {'message': "Welcome to the basic GPT2 chit chat API!"}


# Posting one user message.
@app.post("/infer")
async def infer(data: TextInput):
    data = data.dict()

    user_id = data['user_id']
    message = data['message']

    inferencer.add_message(user_id, 1, message)
    response = await inferencer.predict(user_id)

    return {'message': response}


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=8000, help="The port number.")
    parser.add_argument('--model_path', type=str, required=True, help="The path to the model in HuggingFace Hub.")
    parser.add_argument('--gpu', type=int, default=0, help="The index of GPU to use.")
    parser.add_argument('--max_turns', type=int, default=5, help="The maximum number of dialogue histories to include.")
    parser.add_argument('--top_p', type=float, default=1.0, help="The top p value for nucleus sampling.")
    parser.add_argument('--temperature', type=float, default=1.0, help="The temperature value.")
    parser.add_argument('--num_beams', type=int, default=1, help="The number of beams for beam search.")
              
    args = parser.parse_args()

    # Initializing the inferencer.
    inferencer = Inferencer(args)

    # Running the server.
    uvicorn.run(app, host='127.0.0.1', port=args.port)



How does it look?

Now, let’s see how it works by testing the implementation with Postman[5].

First, we can see the server running on localhost at port 8000:

The screenshot of the terminal.


When we access the endpoint /, we can see the welcome message:

The screenshot of the GET / request.


Now, I’m going to interact with the model.

Using the user ID “devjwsong”, I chatted with the model:

The example of a simple interaction.


We can see that the model returns a corresponding response to each POST /infer request containing user_id and message.

I also printed the inferencer’s chat log for debugging:

The sample chat log on the terminal.
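If you prefer a script to Postman, here is a minimal stdlib client sketch, assuming the server is running locally on the default port 8000. build_infer_request and send are hypothetical helper names; the request body mirrors the TextInput fields.

```python
import json
import urllib.request


def build_infer_request(base_url, user_id, message):
    """Build (but do not send) a POST /infer request with a JSON body."""
    body = json.dumps({"user_id": user_id, "message": message}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/infer",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=60):
    # Sends the request and decodes the JSON response body, e.g. {"message": "..."}.
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, send(build_infer_request("http://127.0.0.1:8000", "devjwsong", "Hi!")) returns the model's reply in the same {"message": ...} shape shown in Postman.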



In this post, we’ve looked through how to make a chat API based on a fine-tuned GPT-2 model with FastAPI.

While this is super simple and far from a large-scale application, where a wide range of features, batch processing, and inference speed must be considered, I hope it gives you a basic understanding of serving an ML model as an API.

My plan for the next post is a review of a paper I recently submitted.

Thank you.



[2] devjwsong/gpt2-dialogue-generation-pytorch. https://github.com/devjwsong/gpt2-dialogue-generation-pytorch
[3] Holtzman, A., Buys, J., Du, L., Forbes, M., & Choi, Y. (2019). The curious case of neural text degeneration. arXiv preprint arXiv:1904.09751. https://arxiv.org/abs/1904.09751.