import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import argparse
from typing import List, Tuple
from threading import Thread
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
# from vllm.utils import FlexibleArgumentParser
from flask import Flask, jsonify, request
import redis
import time
import json

# HTTP interface service
# app = FastAPI()
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50, db=3, password="zhicheng123*")
# decode_responses passed here is ignored because a connection_pool is supplied,
# so values popped from Redis below arrive as bytes and are decoded manually.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'          # Redis list of pending requests
db_key_querying = 'querying'    # Redis set of ids currently in flight
db_key_result = 'result'
batch_size = 15
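
# A minimal producer-side sketch (an assumption; the producer lives outside
# this file). The worker below expects each queue entry to be a JSON string
# with a "path" field pointing at a JSON file holding id/text/model/top_p/
# temperature, and removes the id from the 'querying' set when done:
#
# def enqueue(query_id, path):
#     redis_.sadd(db_key_querying, query_id)
#     redis_.rpush(db_key_query, json.dumps({"path": path}))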

class log:
    """Minimal file logger; used as log.log(...) without instantiation."""

    @staticmethod
    def log(*args, **kwargs):
        fmt = '%Y/%m/%d-%H:%M:%S'
        fmt_day = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(fmt, value)
        dt_log_file = time.strftime(fmt_day, value)
        log_file = 'log_file/access-%s.log' % dt_log_file
        os.makedirs('log_file', exist_ok=True)
        # 'a' appends and creates the file if it does not exist yet.
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine with hard-coded engine arguments."""
    # model_dir = "/home/majiahui/project/models-llm/Qwen-0_5B-Chat"
    # model_dir = "/home/majiahui/project/models-llm/openbuddy-qwen2.5llamaify-7b_train_11_prompt_mistral_gpt_xiaobiaot_real_paper"
    model_dir = "/home/majiahui/project/models-llm/openbuddy-qwen2.5llamaify-7b_train_11_prompt_mistral_gpt_xiaobiaot_real_paper_2"
    args = EngineArgs(model_dir)
    args.max_num_seqs = 16  # at most 16 concurrent sequences per batch
    args.gpu_memory_utilization = 0.8
    args.max_model_len = 8192
    # Load the model
    return LLMEngine.from_engine_args(args)

engine = initialize_engine()


def create_test_prompts(prompt_texts, query_ids, sampling_params_list) -> List[Tuple[str, str, SamplingParams]]:
    """Pair each prompt with its request id and sampling parameters."""
    return list(zip(prompt_texts, query_ids, sampling_params_list))


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""

    return_list = []
    while test_prompts or engine.has_unfinished_requests():
        # Feed one pending prompt per engine step; vLLM batches internally.
        if test_prompts:
            prompt, query_id, sampling_params = test_prompts.pop(0)
            engine.add_request(str(query_id), prompt, sampling_params)

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                return_list.append(request_output)
    return return_list


def main(prompt_texts, query_ids, sampling_params_list):
    """Main function that sets up and runs the prompt processing."""

    test_prompts = create_test_prompts(prompt_texts, query_ids, sampling_params_list)
    return process_requests(engine, test_prompts)
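
# A minimal single-prompt usage sketch of main() (illustrative values only;
# the prompt template mirrors the disabled chat endpoint below):
#
# params = SamplingParams(temperature=0.3, top_p=0.5, stop="<|end|>",
#                         presence_penalty=0.8, max_tokens=8192)
# outputs = main(["<|im_start|>user\nhello\n<|im_end|>\n<|im_start|>assistant\n"],
#                ["req-1"], [params])
# print(outputs[0].outputs[0].text)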


# Chat endpoint (disabled; kept for reference). NOTE: the call to main()
# below predates the current (prompt_texts, query_ids, sampling_params_list)
# signature.
# @app.route("/predict/", methods=["POST"])
# def chat():
#     # request = request.json()
#     # query = request.get('query', None)
#     # history = request.get('history', [])
#     # system = request.get('system', 'You are a helpful assistant.')
#     # stream = request.get("stream", False)
#     # user_stop_words = request.get("user_stop_words",
#     #                               [])  # list[str], user-defined stop strings, e.g.
#     #                                    # ['Observation: ', 'Action: '] defines two; generation halts on either
#
#     query = request.json['query']
#
#     # Build the prompt
#     # prompt_text, prompt_tokens = _build_prompt(generation_config, tokenizer, query, history=history, system=system)
#
#     prompt_text = f"<|im_start|>user\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
#
#     return_output = main(prompt_text, sampling_params)
#     return_info = {
#         "request_id": return_output.request_id,
#         "text": return_output.outputs[0].text
#     }
#
#     return jsonify(return_info)

def classify(batch_size):  # worker loop: call the model with at most batch_size prompts at a time
    while True:
        texts = []
        query_ids = []
        sampling_params_list = []
        if redis_.llen(db_key_query) == 0:  # nothing queued yet, poll again
            time.sleep(2)
            continue

        # Drain the queue until it is empty or a full batch is collected.
        while True:
            query = redis_.lpop(db_key_query)  # pop one request (JSON bytes)
            if query is None:
                break

            query = query.decode('UTF-8')
            data_dict_path = json.loads(query)

            path = data_dict_path['path']
            with open(path, encoding='utf8') as f1:
                # Load the request payload from file
                data_dict = json.load(f1)
            query_id = data_dict['id']
            print("query_id", query_id)
            text = data_dict["text"]
            model = data_dict["model"]
            top_p = data_dict["top_p"]
            temperature = data_dict["temperature"]
            presence_penalty = 0.8
            max_tokens = 8192
            query_ids.append(query_id)
            texts.append(text)
            # sampling_params = SamplingParams(temperature=0.3, top_p=0.5, stop="<|end|>", presence_penalty=1.1, max_tokens=8192)
            sampling_params_list.append(SamplingParams(
                temperature=temperature,
                top_p=top_p,
                stop="<|end|>",
                presence_penalty=presence_penalty,
                max_tokens=max_tokens
            ))
            if len(texts) == batch_size:
                break

        print("texts", len(texts))
        print("query_ids", len(query_ids))
        print("sampling_params_list", len(sampling_params_list))
        outputs = main(texts, query_ids, sampling_params_list)

        print("预测完成")
        generated_text_dict = {}
        print("outputs", len(outputs))
        for i, output in enumerate(outputs):
            index = output.request_id
            print(index)
            generated_text = output.outputs[0].text
            generated_text_dict[index] = generated_text

        print(generated_text_dict)
        for id_, output in generated_text_dict.items():

            return_text = {"texts": output, "probabilities": None, "status_code": 200}
            load_result_path = "./new_data_logs/{}.json".format(id_)
            with open(load_result_path, 'w', encoding='utf8') as f2:
                # ensure_ascii=False才能输入中文,否则是Unicode字符
                # indent=2 JSON数据的缩进,美观
                json.dump(return_text, f2, ensure_ascii=False, indent=4)
            redis_.set(id_, load_result_path, 86400)
            # redis_.set(id_, load_result_path, 30)
            redis_.srem(db_key_querying, id_)
            log.log('start at',
                    'query_id:{},load_result_path:{},return_text:{}'.format(
                        id_, load_result_path, return_text))
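
# A minimal result-retrieval sketch for callers (an assumption about the
# consumer side, not part of this worker): the worker stores the result file
# path under the query id with a 24 h TTL, so a caller can poll:
#
# def fetch_result(query_id):
#     path = redis_.get(query_id)
#     if path is None:
#         return None  # still queued or running
#     with open(path.decode('utf-8'), encoding='utf-8') as f:
#         return json.load(f)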


if __name__ == '__main__':
    t = Thread(target=classify, args=(batch_size,))
    t.start()
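    # Note: only the worker thread is started here; the Flask app defined
    # above is never served. A hypothetical way to also expose it over HTTP
    # (the port number is an assumption, not taken from this file):
    # app.run(host='0.0.0.0', port=5000)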