# -*- coding: utf-8 -*-

"""
@Time    :  2023/3/29 14:27
@Author  :
@FileName:
@Software:
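@Describe: Flask service for generating papers. POST /chat queues a title in Redis;
           a background thread asks the OpenAI chat API for an outline, the section
           texts and an opening report, then builds Word files with external Java jars.
           POST /search reports progress; GET /download serves the generated files.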
"""

import ast
import base64
import json
import os
import re
import subprocess
import threading
import time
import urllib.parse as pa
import uuid
from threading import Thread

import openai
import redis
from flask import Flask, jsonify, Response
from flask import make_response
from flask import request
from flask import send_file, send_from_directory


# Redis holds the request queue (list ``db_key_query``), the set of ids still in
# progress (``db_key_querying``) and the finished results (keyed by request id).
# Connection options such as ``decode_responses`` must be given to the pool:
# ``redis.Redis`` ignores them when ``connection_pool`` is passed, so replies
# arrive as bytes.
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)
redis_ = redis.Redis(connection_pool=pool)

db_key_query = 'query'
db_key_querying = 'querying'
batch_size = 32  # NOTE(review): not referenced anywhere in this file

app = Flask(__name__)
# Return Chinese text unescaped in JSON responses. Flask >= 2.3 no longer reads the
# JSON_AS_ASCII config key; there the switch lives on the app's JSON provider.
app.config["JSON_AS_ASCII"] = False
if hasattr(app, "json"):
    app.json.ensure_ascii = False

import logging

# Guards the worker threads' appends to the shared ``api_key_list`` pool.
lock = threading.RLock()

# Prompt templates (filled with str.format):
#   mulu_prompt        -- ask for a paper outline for the given title;
#   first_title_prompt -- expand a top-level heading ("大标题"), at least 1000 characters;
#   small_title_prompt -- same for a sub-heading ("小标题");
#                         NOTE(review): not used anywhere in this file;
#   references_prompt  -- generate the given section (e.g. references) in Chinese;
#   thank_prompt       -- write out the given section (e.g. acknowledgements);
#   kaitibaogao_prompt -- opening report: content, background, purpose and
#                         significance, at least 1500 characters.
mulu_prompt = "请帮我根据题目为“{}”生成一个论文目录"
first_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的大标题“{}”的内容续写完整,保证续写内容不少于1000字"
small_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的小标题“{}”的内容续写完整,保证续写内容不少于1000字"
references_prompt = "论文题目是“{}”,目录是“{}”,请为这篇论文生成中文的{}"
thank_prompt = "论文题目是“{}”,目录是“{}”,请把其中的{}部分续写完整"
kaitibaogao_prompt = "请以《{}》为题目生成研究的主要的内容、背景、目的、意义,要求不少于1500字"

# Chinese numerals two to nine. NOTE(review): not referenced anywhere in this file.
dabiaoti = ["二","三","四","五","六","七","八","九"]

# Heading patterns: a number (Arabic, Chinese or Roman numeral) followed by "、" or
# "." and Chinese text. The worker skips outline lines before the first match of
# ``pantten_second_biaoti``; ``pantten_other_biaoti`` covers numbers 2-9.
# Raw strings: "\s" in a plain literal is an invalid escape (SyntaxWarning on
# Python 3.12+). ``re`` parses the "\uXXXX" escapes itself, so the patterns are unchanged.
# NOTE(review): the "second" class also contains Ⅰ (Roman one) -- confirm intended.
pantten_second_biaoti = r'[2二ⅡⅠ][、.]\s{0,}?[\u4e00-\u9fa5]+'
pantten_other_biaoti = r'[2-9二三四五六七八九ⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5]+'

# Root directory for per-request output folders (<root>/<request id>/...).
project_data_txt_path = "/home/majiahui/ChatGPT_Sever/new_data_txt"

# Shared key pool: a key is pop()ed for each model request and appended back under
# ``lock`` once the request finishes.
# SECURITY: the built-in keys below are committed to source control. Revoke them
# and provide keys via OPENAI_API_KEYS (comma-separated), which takes precedence.
_env_api_keys = os.environ.get("OPENAI_API_KEYS", "")
api_key_list = [key.strip() for key in _env_api_keys.split(",") if key.strip()] or [
    "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",
    "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",
    "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",
    "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",
]

def chat_title(title, api_key):
    """Ask the chat model for a table of contents (outline) for ``title``.

    The key is handed back to the shared ``api_key_list`` pool when the request
    finishes, including when it raises. A failed request therefore cannot
    permanently drain the pool.

    :param title: paper title inserted into ``mulu_prompt``.
    :param api_key: OpenAI API key previously popped from ``api_key_list``.
    :return: ``(mulu, mulu_list)``: the raw reply text and its non-blank lines,
        each stripped of surrounding whitespace.
    """
    try:
        # Pass the key per request. Assigning ``openai.api_key`` is process-global
        # and races with the other threads issuing requests concurrently.
        res = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": mulu_prompt.format(title)},
            ],
            temperature=0.5,
            api_key=api_key,
        )
    finally:
        with lock:
            api_key_list.append(api_key)
    mulu = res.choices[0].message.content
    # Filter after stripping, so whitespace-only lines do not become "" entries.
    mulu_list = [line.strip() for line in str(mulu).split("\n") if line.strip()]
    return mulu, mulu_list


def chat_kaitibaogao(title, api_key, uuid_path):
    """Generate the opening report (开题报告) for ``title`` and save it to disk.

    The reply is written to ``<uuid_path>/kaitibaogao.txt`` (UTF-8). The key is
    returned to the shared ``api_key_list`` pool afterwards, even when the request
    or the write raises.

    :param title: paper title inserted into ``kaitibaogao_prompt``.
    :param api_key: OpenAI API key previously popped from ``api_key_list``.
    :param uuid_path: existing output directory of this request.
    """
    try:
        # Per-request key: the global ``openai.api_key`` would race with other threads.
        res = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": kaitibaogao_prompt.format(title)},
            ],
            temperature=0.5,
            api_key=api_key,
        )
        kaitibaogao = res.choices[0].message.content
        kaitibaogao_path = os.path.join(uuid_path, "kaitibaogao.txt")
        with open(kaitibaogao_path, 'w', encoding='utf8') as f_kaitibaogao:
            f_kaitibaogao.write(kaitibaogao)
    finally:
        with lock:
            api_key_list.append(api_key)


class GeneratePaper:
    """Collect the generated sections of one paper.

    ``paper`` has one slot per table-of-contents entry. ``chat_content_`` (run as a
    thread target by the worker) fills the slot at its index, so the joined text
    keeps the outline order regardless of which request finishes first.
    """

    def __init__(self, mulu, table):
        """
        :param mulu: outline lines of the paper; stored only.
        :param table: table-of-contents entries; only its length is used here.
        """
        self.mulu = mulu
        self.paper = [""] * len(table)

    def chat_content_(self, api_key, mulu_title_id, title, mulu, subtitle, prompt):
        """Fill ``self.paper[mulu_title_id]`` for one table-of-contents entry.

        Entries prefixed with "@@" are copied verbatim (without the prefix).
        Any other entry is expanded by the chat model using
        ``prompt.format(title, mulu, subtitle)``. The key is returned to
        ``api_key_list`` in every case, including when the request raises.

        :param api_key: OpenAI API key previously popped from ``api_key_list``.
        :param mulu_title_id: index of the slot in ``self.paper`` to fill.
        :param title: paper title.
        :param mulu: outline lines, formatted into the prompt.
        :param subtitle: the table-of-contents entry to produce.
        :param prompt: template with three ``{}`` placeholders.
        """
        try:
            if subtitle.startswith("@@"):
                self.paper[mulu_title_id] = subtitle[2:]
            else:
                # Per-request key: the global ``openai.api_key`` would race with
                # the sibling threads running this method concurrently.
                res = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=[
                        {"role": "user", "content": prompt.format(title, mulu, subtitle)},
                    ],
                    temperature=0.5,
                    api_key=api_key,
                )
                self.paper[mulu_title_id] = res.choices[0].message.content
        finally:
            with lock:
                api_key_list.append(api_key)


_logger = logging.getLogger(__name__)

# Java programs that turn the generated text files into Word documents.
_PAPER_JAR_PATH = "/home/majiahui/ChatGPT_Sever/createAiXieZuoWord.jar"
_KAITI_JAR_PATH = "/home/majiahui/ChatGPT_Sever/createAiXieZuoKaitiWord.jar"


def _wait_for_api_key(poll_interval=1):
    """Pop a key from ``api_key_list``, sleeping until one is back in the pool."""
    while True:
        with lock:
            if api_key_list:
                return api_key_list.pop()
        # Sleep instead of spinning: keys only come back when requests finish.
        time.sleep(poll_interval)


def _build_table_of_contents(mulu_list):
    """Turn the outline lines into the list of entries to produce.

    Entry 0 is always the first outline line. Lines before the first match of
    ``pantten_second_biaoti`` are skipped. From there on, lines matching a heading
    pattern are prefixed with "@@" (copied verbatim later), and all other lines
    are kept for the model to expand.
    """
    if not mulu_list:
        raise ValueError("the model returned an empty outline")
    table_of_contents = [mulu_list[0]]
    started = False
    for line in mulu_list[1:]:
        if re.search(pantten_second_biaoti, line):
            table_of_contents.append("@@" + line)
            started = True
        elif not started:
            continue
        elif re.search(pantten_other_biaoti, line):
            table_of_contents.append("@@" + line)
        else:
            table_of_contents.append(line)
    return table_of_contents


def _select_prompt(index, subtitle):
    """Pick the prompt template for table-of-contents entry ``index``."""
    if index == 0:
        return first_title_prompt
    if subtitle == "参考文献":
        return references_prompt
    if subtitle == "致谢":
        return thank_prompt
    # NOTE(review): ``small_title_prompt`` is defined but unused. The remaining
    # entries may have been meant to use it; kept as-is to preserve output.
    return first_title_prompt


def _run_word_builder(jar_path, *args):
    """Run a Java Word-builder jar, passing ``args`` as separate argv entries.

    An argument list (no shell) replaces the former ``os.system`` string, so paths
    and the title reach the jar verbatim. As before, a failure does not fail the
    job; it is logged instead.
    """
    command = ["java", "-Dfile.encoding=UTF-8", "-jar", jar_path, *args]
    print("java", command)
    try:
        returncode = subprocess.run(command, check=False).returncode
    except OSError:
        _logger.exception("could not start %s", jar_path)
        return
    if returncode != 0:
        _logger.error("%s exited with status %s", jar_path, returncode)


def _forget_query(query_id):
    """Best-effort removal of a failed request id from the ``querying`` set."""
    if query_id is None:
        return
    try:
        redis_.srem(db_key_querying, query_id)
    except redis.RedisError:
        _logger.exception("could not remove %s from %s", query_id, db_key_querying)


def _generate_paper(query_id, title):
    """Produce every artifact for one request and publish its download links.

    Files are written under ``project_data_txt_path/<query_id>``:
    ``kaitibaogao.txt``, ``content.txt`` and ``mulu.txt``, plus ``paper.docx`` and
    ``paper_start.docx`` built by the Java jars. The result is then stored under
    ``query_id`` (expires after 28800 s), and the id is removed from the
    ``querying`` set.
    """
    uuid_path = os.path.join(project_data_txt_path, query_id)
    print("uuid", query_id)
    os.makedirs(uuid_path, exist_ok=True)
    print("uuid_path", os.path.exists(uuid_path))

    thread_list = []
    try:
        # The opening report is generated in a background thread.
        kaiti_thread = Thread(target=chat_kaitibaogao,
                              args=(title, _wait_for_api_key(), uuid_path))
        kaiti_thread.start()
        thread_list.append(kaiti_thread)

        mulu, mulu_list = chat_title(title, _wait_for_api_key())
        print(mulu_list)

        table_of_contents = _build_table_of_contents(mulu_list)
        print(table_of_contents)
        print(len(table_of_contents))
        chat_class = GeneratePaper(mulu_list, table_of_contents)

        # One thread per entry; each fills its own slot of chat_class.paper.
        for index, subtitle in enumerate(table_of_contents):
            api_key = _wait_for_api_key()
            prompt = _select_prompt(index, subtitle)
            # The API key is deliberately left out of this log line.
            print("请求的所有参数", index, title, mulu_list, subtitle, prompt)
            worker = Thread(target=chat_class.chat_content_,
                            args=(api_key, index, title, mulu_list, subtitle, prompt))
            worker.start()
            thread_list.append(worker)
    finally:
        # Also on failure: wait for every started thread, so none is still writing
        # into this request's directory when the next job begins.
        for thread in thread_list:
            thread.join()

    paper = "\n".join(chat_class.paper)
    print(paper)

    content_path = os.path.join(uuid_path, "content.txt")
    with open(content_path, 'w', encoding='utf8') as f_content:
        f_content.write(paper)

    mulu_path = os.path.join(uuid_path, "mulu.txt")
    with open(mulu_path, 'w', encoding='utf8') as f_mulu:
        f_mulu.write(mulu)

    # Written by chat_kaitibaogao in the background thread joined above.
    kaitibaogao_txt_path = os.path.join(uuid_path, "kaitibaogao.txt")
    save_word_paper = os.path.join(uuid_path, "paper.docx")
    save_word_paper_start = os.path.join(uuid_path, "paper_start.docx")

    # The paper builder receives the title percent-encoded (urllib.parse.quote).
    _run_word_builder(_PAPER_JAR_PATH, mulu_path, content_path, pa.quote(title), save_word_paper)
    _run_word_builder(_KAITI_JAR_PATH, kaitibaogao_txt_path, save_word_paper_start)

    url_path_paper = "http://104.244.90.248:14000/download?filename_path={}/paper.docx".format(query_id)
    url_path_kaiti = "http://104.244.90.248:14000/download?filename_path={}/paper_start.docx".format(query_id)

    # Stored as a Python literal (str(dict)); /search parses it back.
    return_text = str({"id": query_id,
                       "content_url_path": url_path_paper,
                       "content_report_url_path": url_path_kaiti,
                       "probabilities": None,
                       "status_code": 200})
    # Publish the result *before* leaving the "querying" set. Otherwise a concurrent
    # /search could find the id in neither place and report it as 203.
    redis_.set(query_id, return_text, 28800)
    redis_.srem(db_key_querying, query_id)


def classify():
    """Background worker loop: consume queued requests from Redis forever.

    Each queued item is the JSON object ``{"id": ..., "title": ...}`` pushed by
    ``/chat``. A failing job is logged, and its id is dropped from the
    ``querying`` set. The thread survives the failure; previously any exception
    killed it and stopped all further processing.
    """
    while True:
        query_id = None
        try:
            raw_query = redis_.lpop(db_key_query)
            if raw_query is None:  # queue empty: poll again later
                time.sleep(3)
                continue
            query = json.loads(raw_query)  # json.loads accepts the bytes redis returns
            query_id = query['id']
            _generate_paper(query_id, query['title'])
        except Exception:
            _logger.exception("paper generation failed (id=%s)", query_id)
            _forget_query(query_id)
            time.sleep(3)  # avoid a hot error loop, e.g. while Redis is unreachable


@app.route("/chat", methods=["POST"])
def chat():
    """Queue a paper-generation request for the background worker.

    Expects a JSON body ``{"title": "<paper title>"}``. Answers immediately with
    a fresh request id, which the client polls via ``/search``. A missing, blank
    or non-string title is rejected with HTTP 400 instead of raising KeyError
    (HTTP 500).
    """
    print(request.remote_addr)
    payload = request.get_json(silent=True)
    title = payload.get("title") if isinstance(payload, dict) else None
    if not isinstance(title, str) or not title.strip():
        return jsonify({"texts": {}, "probabilities": None, "status_code": 400}), 400

    id_ = str(uuid.uuid1())
    # Register the id as in progress *before* queueing it. Otherwise /search could
    # see a queued id that is not yet registered, and a worker finishing first would
    # srem the id before this sadd, leaving it in the set forever.
    redis_.sadd(db_key_querying, id_)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "title": title}))  # enqueue in Redis
    print("ok")
    return jsonify({"texts": {'id': id_}, "probabilities": None, "status_code": 200})


@app.route("/download", methods=['GET'])
def download_file():
    """Send a generated file from ``project_data_txt_path`` as an attachment.

    The query parameter ``filename_path`` is relative to the data directory,
    e.g. ``"<request id>/paper.docx"``. It is client-controlled. Any path that
    resolves outside the data directory (``..`` segments, absolute paths,
    symlinks) or does not name an existing file gets 404 instead of being read.
    """
    filename_path = request.args.get('filename_path', '')
    base_dir = os.path.realpath(project_data_txt_path)
    path_name = os.path.realpath(os.path.join(base_dir, filename_path))
    # SECURITY: os.path.join alone lets "../x" or "/etc/passwd" escape base_dir.
    if os.path.commonpath([base_dir, path_name]) != base_dir or not os.path.isfile(path_name):
        return Response("file not found", status=404)

    with open(path_name, 'rb') as f:
        stream = f.read()
    response = Response(stream, content_type='application/octet-stream')
    response.headers['Content-disposition'] = 'attachment; filename={}'.format(
        os.path.basename(path_name))
    return response


@app.route("/search", methods=["POST"])
def search():
    """Report the status of a request created by ``/chat``.

    Expects a JSON body ``{"id": "<id returned by /chat>"}``. The ``code`` field is:

    * ``200``: finished; the two download URLs are included;
    * ``"201"``: id is registered and still waiting in the ``query`` list;
    * ``"202"``: id is registered but no longer queued, i.e. being generated;
    * ``"203"``: anything else (unknown id, or a result expired from Redis);
    * ``"400"``: the body carries no string ``id``.
    """
    payload = request.get_json(silent=True)
    id_ = payload.get('id') if isinstance(payload, dict) else None
    if not isinstance(id_, str):
        return jsonify({'code': "400", 'text': "", 'probabilities': None}), 400

    result = redis_.get(id_)  # finished result, stored under the request id
    if result is not None:
        if isinstance(result, bytes):
            result = result.decode('UTF-8')
        # The worker stores str(dict). ast.literal_eval parses that Python literal
        # without executing code, unlike eval() on data read back from Redis.
        result_dict = ast.literal_eval(result)
        return jsonify({'code': 200,
                        'content_url_path': result_dict["content_url_path"],
                        'content_report_url_path': result_dict["content_report_url_path"],
                        'probabilities': result_dict["probabilities"]})

    if not redis_.sismember(db_key_querying, id_):
        return jsonify({'code': "203", 'text': "", 'probabilities': None})

    # Registered: distinguish "still queued" from "taken by the worker".
    still_queued = any(json.loads(item)['id'] == id_
                       for item in redis_.lrange(db_key_query, 0, -1))
    code = "201" if still_queued else "202"
    return jsonify({'code': code, 'text': "", 'probabilities': None})

# The worker is started at import time (outside the __main__ guard), so it runs
# whenever this module is imported. daemon=True lets the interpreter exit, e.g.
# on Ctrl-C of the dev server, instead of waiting forever on classify()'s endless
# loop. A job still in progress at shutdown is abandoned.
t = Thread(target=classify, name="classify-worker", daemon=True)
t.start()

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)