# coding=utf-8
# Chat with LLaMA in API mode.
# Usage: python api.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
#
# Provenance: src/api.py as produced by the patch series "Support conversation via API."
# (commits e82168243060ba77ff458a5b1e8479468b8e7265 and 748b804bac8e2c972246bc7f7d50884fe105a7fe,
# mMrBun <2015711377@qq.com>, 2023-05-30).

# Call:
#   curl --location 'http://127.0.0.1:8000' \
#     --header 'Content-Type: application/json' \
#     --data '{"prompt": "Hello there!","history": []}'

# Response:
# {
#   "response":"'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major
#   is Computer Science and my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations
#   Research & Management Sciences which means that when it comes to solving problems using computers or any kind data
#   analysis; whether its from businesses , governments etc., i can help you out :) .'",
#   "history":"[('Hello there!',
#   'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major is Computer Science and
#   my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations Research & Management
#   Sciences which means that when it comes to solving problems using computers or any kind data analysis; whether its
#   from businesses , governments etc., i can help you out :) .')]",
#   "status":200,
#   "time":"2023-05-30 06:57:38"
# }
#
# Note that "response" and "history" are Python repr() strings, not nested JSON values.
# Malformed requests are rejected with HTTP 400 and a {"detail": ...} body.

import datetime
import json

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from transformers import HfArgumentParser

from utils import ModelArguments, auto_configure_device_map, load_pretrained

DEVICE = "cuda"

# Sampling configuration passed to model.generate() for every request.
GENERATION_KWARGS = {
    "do_sample": True,
    "top_p": 0.9,
    "top_k": 40,
    "temperature": 0.7,
    "num_beams": 1,
    "max_new_tokens": 256,
    "repetition_penalty": 1.5,
}


def torch_gc():
    """Call torch.cuda.empty_cache() and torch.cuda.ipc_collect() on every CUDA device.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    for device_id in range(torch.cuda.device_count()):
        with torch.cuda.device(device_id):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


async def _read_chat_request(request):
    """Parse and validate the request body, returning ``(prompt, history)``.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not a
            JSON object, lacks a string "prompt", or has a non-list "history".
    """
    try:
        payload = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object.")

    prompt = payload.get("prompt")
    if not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail='"prompt" is required and must be a string.')

    # A missing or null history means "start of conversation".
    history = payload.get("history")
    if history is None:
        history = []
    elif not isinstance(history, list):
        raise HTTPException(status_code=400, detail='"history" must be a list.')

    return prompt, history


@app.post("/")
async def create_item(request: Request):
    """Generate a reply for the posted prompt.

    Expects a JSON object with a string "prompt" and an optional list
    "history". Returns a dict with:
        response: repr() of the generated text,
        history:  repr() of the incoming history plus the new (prompt, response) turn,
        status:   200,
        time:     server-local timestamp formatted as "%Y-%m-%d %H:%M:%S".

    Raises:
        HTTPException: status 400 for a malformed request body.
    """
    # model and tokenizer are assigned at module level by the __main__ section below.
    global model, tokenizer

    prompt, history = await _read_chat_request(request)

    # NOTE(review): only `prompt` is tokenized; `history` is echoed back with the
    # new turn appended but does not condition generation. Confirm whether earlier
    # turns are meant to be folded into the model input.
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = inputs.to(model.device)

    try:
        with torch.no_grad():
            generation_output = model.generate(**inputs, **GENERATION_KWARGS)
        # Drop the echoed prompt tokens so only the newly generated ones are decoded.
        prompt_length = len(inputs["input_ids"][0])
        new_tokens = generation_output.tolist()[0][prompt_length:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    finally:
        # Release cached GPU memory even when generation fails.
        torch_gc()

    history = history + [(prompt, response)]

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": timestamp,
    }

    print(f'[{timestamp}] prompt:"{prompt}", response:"{response!r}"')

    return answer


if __name__ == "__main__":
    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args)

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model

        # Spread the model across all GPUs using the project's device-map helper.
        device_map = auto_configure_device_map(torch.cuda.device_count())
        model = dispatch_model(model, device_map)
    else:
        model = model.cuda()

    model.eval()

    # Single worker: the model lives in this process's module globals.
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)