
5 changed files with 170 additions and 157 deletions
@@ -1,115 +0,0 @@
# coding=utf-8
# Chat with LLaMA in API mode.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

# Call:
# curl --location 'http://127.0.0.1:8000' \
# --header 'Content-Type: application/json' \
# --data '{"prompt": "Hello there!","history": []}'

# Response:
# {
# "response":"'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major
# is Computer Science and my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations
# Research & Management Sciences which means that when it comes to solving problems using computers or any kind data
# analysis; whether its from businesses , governments etc., i can help you out :) .'",
# "history":"[('Hello there!',
# 'I am a second year student at the University of British Columbia, in Vancouver.\\nMy major is Computer Science and
# my minor (double degree) area was Mathematics/Statistics with an emphasis on Operations Research & Management
# Sciences which means that when it comes to solving problems using computers or any kind data analysis; whether its
# from businesses , governments etc., i can help you out :) .')]",
# "status":200,
# "time":"2023-05-30 06:57:38" }


import datetime
import torch
from utils import ModelArguments, auto_configure_device_map, load_pretrained
from transformers import HfArgumentParser
import json
import uvicorn
from fastapi import FastAPI, Request


DEVICE = "cuda"


def torch_gc():
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        for device_id in range(num_gpus):
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer

    # Parse the request JSON
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')

    # Tokenize the input prompt
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = inputs.to(model.device)

    # Generation arguments
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.9,
        "top_k": 40,
        "temperature": 0.7,
        "num_beams": 1,
        "max_new_tokens": 256,
        "repetition_penalty": 1.5
    }

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(**inputs, **gen_kwargs)
    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    # Update history
    history = history + [(prompt, response)]

    # Prepare response
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": time
    }

    # Log and clean up
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()

    return answer


if __name__ == "__main__":
    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args)

    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model

        device_map = auto_configure_device_map(torch.cuda.device_count())
        model = dispatch_model(model, device_map)
    else:
        model = model.cuda()

    model.eval()

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
@@ -0,0 +1,118 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

# Request:
# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'

# Response:
# {
#   "response": "'Hi there!'",
#   "history": "[('Hello there!', 'Hi there!')]",
#   "status": 200,
#   "time": "2000-00-00 00:00:00"
# }


import json
import torch
import uvicorn
import datetime
from fastapi import FastAPI, Request

from utils import (
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)

def torch_gc():
    if not torch.cuda.is_available():
        return
    num_gpus = torch.cuda.device_count()
    for device_id in range(num_gpus):
        with torch.cuda.device(device_id):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer, format_example

    # Parse the request JSON
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get("prompt")
    history = json_post_list.get("history")

    # Tokenize the input prompt
    input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Generation arguments
    gen_kwargs = {
        "input_ids": input_ids,
        "do_sample": True,
        "top_p": 0.7,
        "temperature": 0.95,
        "num_beams": 1,
        "max_new_tokens": 512,
        "repetition_penalty": 1.0,
        "logits_processor": get_logits_processor()
    }

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    # Update history
    history = history + [(prompt, response)]

    # Prepare response
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": time
    }

    # Log and clean up
    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
    print(log)
    torch_gc()

    return answer


if __name__ == "__main__":
    model_args, data_args, finetuning_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    def format_example_alpaca(query, history):
        prompt = "Below is an instruction that describes a task. "
        prompt += "Write a response that appropriately completes the request.\n"
        prompt += "Instruction:\n"
        for old_query, response in history:
            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
        prompt += "Human: {}\nAssistant:".format(query)
        return prompt

    def format_example_ziya(query, history):
        prompt = ""
        for old_query, response in history:
            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
        prompt += "<human>: {}\n<bot>:".format(query)
        return prompt

    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
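
For quick local testing, a Python client equivalent to the curl example in the new file's header might look like the sketch below. The `requests` package is an assumption (it is not part of this change; any HTTP client works), while the URL, port, and payload fields follow the server code above.

import requests  # assumption: not part of this PR, any HTTP client would do

# Assumes api_demo.py is running locally (uvicorn serves on 0.0.0.0:8000 as configured above).
payload = {"prompt": "Hello there!", "history": []}
resp = requests.post("http://127.0.0.1:8000", json=payload, timeout=300)
data = resp.json()

print(data["status"], data["time"])
print(data["response"])  # repr()-encoded model reply
print(data["history"])   # repr()-encoded (prompt, response) history, returned as a string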