From 3d8d5ee5d54102dd73856fac3a80922ea3104a06 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 5 Jun 2023 21:32:18 +0800
Subject: [PATCH] add API demo from #1

---
 src/api.py         | 115 ---------------------------------------------------
 src/api_demo.py    | 118 +++++++++++++++++++++++++++++++++++++++++++++++++++
 src/cli_demo.py    |   2 +-
 src/utils/other.py |   2 +-
 src/web_demo.py    |  90 ++++++++++++++++++++++------------------
 5 files changed, 170 insertions(+), 157 deletions(-)
 delete mode 100644 src/api.py
 create mode 100644 src/api_demo.py

diff --git a/src/api.py b/src/api.py
deleted file mode 100644
index ad8fda2..0000000
--- a/src/api.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# coding=utf-8
-# Chat with LLaMA in API mode.
-# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-
-# Call:
-# curl --location 'http://127.0.0.1:8000' \
-# --header 'Content-Type: application/json' \
-# --data '{"prompt": "Hello there!","history": []}'
-
-# Response:
-# {
-#   "response":"'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major
-# is Computer Science and my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations
-# Research & Management Sciences which means that when it comes to solving problems using computers or any kind data
-# analysis; whether its from businesses , governments etc., i can help you out :) .'",
-#   "history":"[('Hello there!',
-# 'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major is Computer Science and
-# my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations Research & Management
-# Sciences which means that when it comes to solving problems using computers or any kind data analysis; whether its
-# from businesses , governments etc., i can help you out :) .')]",
-#   "status":200,
-#   "time":"2023-05-30 06:57:38" }
-
-import datetime
-import torch
-from utils import ModelArguments, auto_configure_device_map, load_pretrained
-from transformers import HfArgumentParser
-import json
-import uvicorn
-from fastapi import FastAPI, Request
-
-DEVICE = "cuda"
-
-
-def torch_gc():
-    if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        for device_id in range(num_gpus):
-            with torch.cuda.device(device_id):
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
-
-
-app = FastAPI()
-
-
-@app.post("/")
-async def create_item(request: Request):
-    global model, tokenizer
-
-    # Parse the request JSON
-    json_post_raw = await request.json()
-    json_post = json.dumps(json_post_raw)
-    json_post_list = json.loads(json_post)
-    prompt = json_post_list.get('prompt')
-    history = json_post_list.get('history')
-
-    # Tokenize the input prompt
-    inputs = tokenizer([prompt], return_tensors="pt")
-    inputs = inputs.to(model.device)
-
-    # Generation arguments
-    gen_kwargs = {
-        "do_sample": True,
-        "top_p": 0.9,
-        "top_k": 40,
-        "temperature": 0.7,
-        "num_beams": 1,
-        "max_new_tokens": 256,
-        "repetition_penalty": 1.5
-    }
-
-    # Generate response
-    with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
-    response = tokenizer.decode(outputs, skip_special_tokens=True)
-
-    # Update history
-    history = history + [(prompt, response)]
-
-    # Prepare response
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "response": repr(response),
-        "history": repr(history),
-        "status": 200,
-        "time": time
-    }
-
-    # Log and clean up
-    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
-    print(log)
-    torch_gc()
-
-    return answer
-
-
-if __name__ == "__main__":
-    parser = HfArgumentParser(ModelArguments)
-    model_args, = parser.parse_args_into_dataclasses()
-    model, tokenizer = load_pretrained(model_args)
-
-    if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model
-
-        device_map = auto_configure_device_map(torch.cuda.device_count())
-        model = dispatch_model(model, device_map)
-    else:
-        model = model.cuda()
-
-    model.eval()
-
-    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/src/api_demo.py b/src/api_demo.py
new file mode 100644
index 0000000..ca5e05d
--- /dev/null
+++ b/src/api_demo.py
@@ -0,0 +1,118 @@
+# coding=utf-8
+# Implements API for fine-tuned models.
+# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
+
+# Request:
+# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
+
+# Response:
+# {
+#   "response": "'Hi there!'",
+#   "history": "[('Hello there!', 'Hi there!')]",
+#   "status": 200,
+#   "time": "2000-00-00 00:00:00"
+# }
+
+
+import json
+import torch
+import uvicorn
+import datetime
+from fastapi import FastAPI, Request
+
+from utils import (
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        num_gpus = torch.cuda.device_count()
+        for device_id in range(num_gpus):
+            with torch.cuda.device(device_id):
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+
+
+app = FastAPI()
+
+
+@app.post("/")
+async def create_item(request: Request):
+    global model, tokenizer, format_example
+
+    # Parse the request JSON
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+    prompt = json_post_list.get("prompt")
+    history = json_post_list.get("history")
+
+    # Tokenize the input prompt
+    input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
+
+    # Generation arguments
+    gen_kwargs = {
+        "input_ids": input_ids,
+        "do_sample": True,
+        "top_p": 0.7,
+        "temperature": 0.95,
+        "num_beams": 1,
+        "max_new_tokens": 512,
+        "repetition_penalty": 1.0,
+        "logits_processor": get_logits_processor()
+    }
+
+    # Generate response
+    with torch.no_grad():
+        generation_output = model.generate(**gen_kwargs)
+    outputs = generation_output.tolist()[0][len(input_ids[0]):]
+    response = tokenizer.decode(outputs, skip_special_tokens=True)
+
+    # Update history
+    history = history + [(prompt, response)]
+
+    # Prepare response
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+    answer = {
+        "response": repr(response),
+        "history": repr(history),
+        "status": 200,
+        "time": time
+    }
+
+    # Log and clean up
+    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
+    print(log)
+    torch_gc()
+
+    return answer
+
+
+if __name__ == "__main__":
+    model_args, data_args, finetuning_args = prepare_infer_args()
+    model, tokenizer = load_pretrained(model_args, finetuning_args)
+
+    def format_example_alpaca(query, history):
+        prompt = "Below is an instruction that describes a task. "
+        prompt += "Write a response that appropriately completes the request.\n"
+        prompt += "Instruction:\n"
+        for old_query, response in history:
+            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+        prompt += "Human: {}\nAssistant:".format(query)
+        return prompt
+
+    def format_example_ziya(query, history):
+        prompt = ""
+        for old_query, response in history:
+            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+        prompt += "<human>: {}\n<bot>:".format(query)
+        return prompt
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
+    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/src/cli_demo.py b/src/cli_demo.py
index e8e4fb3..9aae2f7 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Implements stream chat in command line for fine-tuned models.
-# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
+# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
 
 from utils import (
diff --git a/src/utils/other.py b/src/utils/other.py
index 4e2d561..77ef30b 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -142,7 +142,7 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
 
 
 def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
-    """
+    r"""
     EMA implementation according to TensorBoard.
     """
     last = scalars[0]
diff --git a/src/web_demo.py b/src/web_demo.py
index 96bfbcd..54bf634 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Implements user interface in browser for fine-tuned models.
-# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
+# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
 
 import mdtex2html
@@ -13,13 +13,38 @@ from transformers.utils.versions import require_version
 
 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
+
+
 
 model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
 
-"""Override Chatbot.postprocess"""
+def format_example_alpaca(query, history):
+    prompt = "Below is an instruction that describes a task. "
+    prompt += "Write a response that appropriately completes the request.\n"
+    prompt += "Instruction:\n"
+    for old_query, response in history:
+        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+    prompt += "Human: {}\nAssistant:".format(query)
+    return prompt
+
+
+def format_example_ziya(query, history):
+    prompt = ""
+    for old_query, response in history:
+        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+    prompt += "<human>: {}\n<bot>:".format(query)
+    return prompt
+
+
+format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
 
 
 def postprocess(self, y):
+    r"""
+    Overrides Chatbot.postprocess
+    """
     if y is None:
         return []
     for i, (message, response) in enumerate(y):
@@ -40,11 +65,11 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     for i, line in enumerate(lines):
         if "```" in line:
             count += 1
-            items = line.split('`')
+            items = line.split("`")
             if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
+                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
             else:
-                lines[i] = f'<br></code></pre>'
+                lines[i] = "<br /></code></pre>"
         else:
             if i > 0:
                 if count % 2 == 1:
@@ -60,37 +85,15 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
                     line = line.replace("(", "&#40;")
                     line = line.replace(")", "&#41;")
                     line = line.replace("$", "&#36;")
-                lines[i] = "<br>"+line
+                lines[i] = "<br />" + line
     text = "".join(lines)
     return text
 
 
-def format_example_alpaca(query, history):
-    prompt = "Below is an instruction that describes a task. "
-    prompt += "Write a response that appropriately completes the request.\n"
-    prompt += "Instruction:\n"
-    for old_query, response in history:
-        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-    prompt += "Human: {}\nAssistant:".format(query)
-    return prompt
-
-
-def format_example_ziya(query, history):
-    prompt = ""
-    for old_query, response in history:
-        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-    prompt += "<human>: {}\n<bot>:".format(query)
-    return prompt
-
-
-format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
-streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-
-def predict(input, chatbot, max_length, top_p, temperature, history):
-    chatbot.append((parse_text(input), ""))
+def predict(query, chatbot, max_length, top_p, temperature, history):
+    chatbot.append((parse_text(query), ""))
 
-    input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
+    input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "input_ids": input_ids,
@@ -108,13 +111,13 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
     response = ""
     for new_text in streamer:
         response += new_text
-        new_history = history + [(input, response)]
-        chatbot[-1] = (parse_text(input), parse_text(response))
+        new_history = history + [(query, response)]
+        chatbot[-1] = (parse_text(query), parse_text(response))
         yield chatbot, new_history
 
 
 def reset_user_input():
-    return gr.update(value='')
+    return gr.update(value="")
 
 
 def reset_state():
@@ -122,26 +125,33 @@
 
 
 with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
+
+    gr.HTML("""
+    <h1 align="center">
+        <a href="https://github.com/hiyouga/LLaMA-Efficient-Tuning" target="_blank">
+            LLaMA Efficient Tuning
+        </a>
+    </h1>
+    """)
+
     chatbot = gr.Chatbot()
+
     with gr.Row():
         with gr.Column(scale=4):
            with gr.Column(scale=12):
-                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
-                    container=False)
+                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
             with gr.Column(min_width=32, scale=1):
                 submitBtn = gr.Button("Submit", variant="primary")
+
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
+            temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True)
 
     history = gr.State([])
 
-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
-                    show_progress=True)
+    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
     submitBtn.click(reset_user_input, [], [user_input])
 
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
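Note (not part of the patch): a minimal Python client sketch equivalent to the curl call in the api_demo.py header, for readers trying out the new endpoint. It assumes api_demo.py is already running locally on port 8000 as shown in its Usage line, and it uses the third-party `requests` package purely for illustration; it is not a dependency introduced by this change.

# Hypothetical client for the server started by api_demo.py (illustration only).
import requests

payload = {"prompt": "Hello there!", "history": []}
resp = requests.post("http://127.0.0.1:8000", json=payload, timeout=120)
data = resp.json()

# Per the header comment, "response" and "history" come back as repr() strings,
# alongside a "status" code and a "time" stamp.
print(data["status"], data["time"])
print(data["response"])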