
committed by
GitHub

1 changed files with 115 additions and 0 deletions
@ -0,0 +1,115 @@ |
|||
# coding=utf-8 |
|||
# Chat with LLaMA in API mode. |
|||
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint |
|||
|
|||
# Call: |
|||
# curl --location 'http://127.0.0.1:8000' \ |
|||
# --header 'Content-Type: application/json' \ |
|||
# --data '{"prompt": "Hello there!","history": []}' |
|||
|
|||
# Response: |
|||
# { |
|||
# "response":"'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major |
|||
# is Computer Science and my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations |
|||
# Research & Management Sciences which means that when it comes to solving problems using computers or any kind data |
|||
# analysis; whether its from businesses , governments etc., i can help you out :) .'", |
|||
# "history":"[('Hello there!', |
|||
# 'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major is Computer Science and |
|||
# my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations Research & Management |
|||
# Sciences which means that when it comes to solving problems using computers or any kind data analysis; whether its |
|||
# from businesses , governments etc., i can help you out :) .')]", |
|||
# "status":200, |
|||
# "time":"2023-05-30 06:57:38" } |
|||
|
|||
import datetime |
|||
import torch |
|||
from utils import ModelArguments, auto_configure_device_map, load_pretrained |
|||
from transformers import HfArgumentParser |
|||
import json |
|||
import uvicorn |
|||
from fastapi import FastAPI, Request |
|||
|
|||
DEVICE = "cuda" |
|||
|
|||
|
|||
def torch_gc(): |
|||
if torch.cuda.is_available(): |
|||
num_gpus = torch.cuda.device_count() |
|||
for device_id in range(num_gpus): |
|||
with torch.cuda.device(device_id): |
|||
torch.cuda.empty_cache() |
|||
torch.cuda.ipc_collect() |
|||
|
|||
|
|||
app = FastAPI() |
|||
|
|||
|
|||
@app.post("/") |
|||
async def create_item(request: Request): |
|||
global model, tokenizer |
|||
|
|||
# Parse the request JSON |
|||
json_post_raw = await request.json() |
|||
json_post = json.dumps(json_post_raw) |
|||
json_post_list = json.loads(json_post) |
|||
prompt = json_post_list.get('prompt') |
|||
history = json_post_list.get('history') |
|||
|
|||
# Tokenize the input prompt |
|||
inputs = tokenizer([prompt], return_tensors="pt") |
|||
inputs = inputs.to(model.device) |
|||
|
|||
# Generation arguments |
|||
gen_kwargs = { |
|||
"do_sample": True, |
|||
"top_p": 0.9, |
|||
"top_k": 40, |
|||
"temperature": 0.7, |
|||
"num_beams": 1, |
|||
"max_new_tokens": 256, |
|||
"repetition_penalty": 1.5 |
|||
} |
|||
|
|||
# Generate response |
|||
with torch.no_grad(): |
|||
generation_output = model.generate(**inputs, **gen_kwargs) |
|||
outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):] |
|||
response = tokenizer.decode(outputs, skip_special_tokens=True) |
|||
|
|||
# Update history |
|||
history = history + [(prompt, response)] |
|||
|
|||
# Prepare response |
|||
now = datetime.datetime.now() |
|||
time = now.strftime("%Y-%m-%d %H:%M:%S") |
|||
answer = { |
|||
"response": repr(response), |
|||
"history": repr(history), |
|||
"status": 200, |
|||
"time": time |
|||
} |
|||
|
|||
# Log and clean up |
|||
log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' |
|||
print(log) |
|||
torch_gc() |
|||
|
|||
return answer |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
parser = HfArgumentParser(ModelArguments) |
|||
model_args, = parser.parse_args_into_dataclasses() |
|||
model, tokenizer = load_pretrained(model_args) |
|||
|
|||
if torch.cuda.device_count() > 1: |
|||
from accelerate import dispatch_model |
|||
|
|||
device_map = auto_configure_device_map(torch.cuda.device_count()) |
|||
model = dispatch_model(model, device_map) |
|||
else: |
|||
model = model.cuda() |
|||
|
|||
model.eval() |
|||
|
|||
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) |
Loading…
Reference in new issue