# coding=utf-8
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

# Request:
# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'

# Response:
# {
#   "response": "'Hi there!'",
#   "history": "[('Hello there!', 'Hi there!')]",
#   "status": 200,
#   "time": "2000-00-00 00:00:00"
# }
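
# A minimal sketch of issuing the same request from Python instead of curl.
# Assumption: the `requests` package is installed; it is not otherwise used by
# this script. Host and port mirror the uvicorn defaults at the bottom of this file.
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:8000/",
#       json={"prompt": "Hello there!", "history": []},
#   )
#   data = resp.json()
#   print(data["response"])  # repr() of the model reply
#   print(data["history"])   # repr() of the updated (prompt, response) list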

import json
import torch
import uvicorn
import datetime
from fastapi import FastAPI, Request

from utils import (
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)


def torch_gc():
    # Free cached GPU memory on every visible CUDA device.
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        for device_id in range(num_gpus):
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer, format_example

    # Parse the request JSON
    json_post = await request.json()
    prompt = json_post.get("prompt")
    history = json_post.get("history")

    # Tokenize the input prompt
    input_ids = tokenizer([format_example(prompt, history)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Generation arguments
    gen_kwargs = {
        "input_ids": input_ids,
        "do_sample": True,
        "top_p": 0.7,
        "temperature": 0.95,
        "num_beams": 1,
        "max_new_tokens": 512,
        "repetition_penalty": 1.0,
        "logits_processor": get_logits_processor()
    }

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    # Update history
    history = history + [(prompt, response)]

    # Prepare response
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": time
    }

    # Log and clean up
    log = "[" + time + "] " + "prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
    print(log)
    torch_gc()

    return answer


if __name__ == "__main__":
    model_args, data_args, finetuning_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    def format_example_alpaca(query, history):
        prompt = "Below is an instruction that describes a task. "
        prompt += "Write a response that appropriately completes the request.\n"
        prompt += "Instruction:\n"
        for old_query, response in history:
            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
        prompt += "Human: {}\nAssistant:".format(query)
        return prompt

    def format_example_ziya(query, history):
        prompt = ""
        for old_query, response in history:
            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
        prompt += "<human>: {}\n<bot>:".format(query)
        return prompt

    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
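
    # Illustration only (not executed): with an empty history, the "alpaca"
    # template above renders the prompt
    #
    #   Below is an instruction that describes a task. Write a response that appropriately completes the request.
    #   Instruction:
    #   Human: Hello there!
    #   Assistant:
    #
    # while the ziya-style template renders
    #
    #   <human>: Hello there!
    #   <bot>: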

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)