|
@ -42,7 +42,7 @@ app = FastAPI() |
|
|
|
|
|
|
|
|
@app.post("/") |
|
|
@app.post("/") |
|
|
async def create_item(request: Request): |
|
|
async def create_item(request: Request): |
|
|
global model, tokenizer, prompt_template, generating_args |
|
|
global model, tokenizer, prompt_template, source_prefix, generating_args |
|
|
|
|
|
|
|
|
# Parse the request JSON |
|
|
# Parse the request JSON |
|
|
json_post_raw = await request.json() |
|
|
json_post_raw = await request.json() |
|
@ -55,7 +55,7 @@ async def create_item(request: Request): |
|
|
temperature = json_post_list.get("temperature", None) |
|
|
temperature = json_post_list.get("temperature", None) |
|
|
|
|
|
|
|
|
# Tokenize the input prompt |
|
|
# Tokenize the input prompt |
|
|
input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"] |
|
|
input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"] |
|
|
input_ids = input_ids.to(model.device) |
|
|
input_ids = input_ids.to(model.device) |
|
|
|
|
|
|
|
|
# Generation arguments |
|
|
# Generation arguments |
|
@ -94,8 +94,11 @@ async def create_item(request: Request): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|
|
|
|
|
|
|
model_args, data_args, finetuning_args, generating_args = prepare_infer_args() |
|
|
model_args, data_args, finetuning_args, generating_args = prepare_infer_args() |
|
|
model, tokenizer = load_pretrained(model_args, finetuning_args) |
|
|
model, tokenizer = load_pretrained(model_args, finetuning_args) |
|
|
|
|
|
|
|
|
prompt_template = Template(data_args.prompt_template) |
|
|
prompt_template = Template(data_args.prompt_template) |
|
|
|
|
|
source_prefix = data_args.source_prefix if data_args.source_prefix else "" |
|
|
|
|
|
|
|
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) |
|
|
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) |
|
|