import logging
import operator
import os
import re
import time
import unicodedata

import torch
from flask import Flask, jsonify, request
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


logging.basicConfig(level=logging.DEBUG,  # log level
                    filename='rewrite.log',
                    filemode='a',  # 'w' rewrites the log on every run, overwriting the old one;
                    # 'a' (the default when omitted) appends
                    format=
                    '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
                    # log format
                    )
db_key_query = 'query'
batch_size = 32
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # keep Chinese characters unescaped in JSON responses

pattern = r"[。]"  # sentence-splitting pattern (full stop)
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")  # quoted dialogue spans
fuhao_end_sentence = ["。", ",", "?", "!", "…"]  # punctuation treated as sentence-ending

# Load the MacBERT4CSC checkpoint for Chinese spelling correction
# ("macbert4csc-base-chinese" is expected as a local directory or cached model id).
tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
model.eval()  # inference only: make eval mode (no dropout) explicit


def is_chinese(char):
    # unicodedata.name() raises ValueError for code points without a name
    # (e.g. control characters); those are never CJK.
    try:
        return 'CJK' in unicodedata.name(char)
    except ValueError:
        return False
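
# A quick sanity sketch of the behaviour above (illustrative only, not part
# of the service):
#
#   is_chinese('中')   -> True   (CJK UNIFIED IDEOGRAPH-4E2D)
#   is_chinese('a')    -> False  (LATIN SMALL LETTER A)
#   is_chinese('\n')   -> False  (no Unicode name; handled by the except branch)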


class SentenceUlit:
    """Split a sentence into runs of consecutive Chinese characters.

    Non-Chinese characters keep their original positions in sentence_list;
    the Chinese runs are collected into sentence_batch for correction and
    written back by inf_ulit.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)
        self.last_post = False
        self.sentence_batch = []
        self.inf_sentence_batch_str = ""
        self.pre_ulit()

    def pre_ulit(self):
        for i, char in enumerate(self.sentence):
            if not is_chinese(char):
                # Non-Chinese characters stay in place and end the current run.
                self.sentence_list[i] = char
                self.last_post = False
            else:
                if not self.last_post:
                    # Start a new run of Chinese characters.
                    self.sentence_batch.append(char)
                else:
                    # Extend the current run.
                    self.sentence_batch[-1] += char
                self.last_post = True

    def inf_ulit(self, sen):
        # Concatenate the corrected Chinese runs, then fill the corrected
        # characters back into the empty (Chinese) slots of sentence_list.
        for chunk in sen:
            self.inf_sentence_batch_str += chunk
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)

        for i, d in enumerate(self.sentence_list):
            if d == "":
                self.sentence_list[i] = self.inf_sentence_batch_srt_list.pop(0)
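
# A minimal usage sketch (assumes the correction step returns one corrected
# string per Chinese run; the input and `corrected` below are illustrative):
#
#   s = SentenceUlit("Hello,你号世界!")
#   # s.sentence_batch -> ["你号世界"]; correct each run, e.g.:
#   corrected = ["你好世界"]
#   s.inf_ulit(corrected)
#   print("".join(s.sentence_list))  # -> "Hello,你好世界!"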


class log:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        # Write a timestamped line to a per-day access log under log_file/.
        fmt = '%Y/%m/%d-%H:%M:%S'
        fmt_day = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(fmt, value)
        dt_log_file = time.strftime(fmt_day, value)
        log_file = 'log_file/access-%s.log' % dt_log_file
        os.makedirs('log_file', exist_ok=True)
        # 'a' creates the file if it does not exist, so one branch suffices.
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)

def get_errors(corrected_text, origin_text):
    """Re-align corrected_text with origin_text and collect substitutions.

    Returns the re-aligned corrected text and a list of
    (original_char, corrected_char, start, end) tuples.
    """
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # Re-insert characters the tokenizer drops or maps to [UNK].
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # Keep the original casing of English characters.
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details
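
# A minimal sketch of the alignment (illustrative values only):
#
#   corrected, details = get_errors("你好世界", "你号世界")
#   # corrected -> "你好世界"
#   # details   -> [("号", "好", 1, 2)]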


def main(texts):
    # Run masked-LM inference on the whole batch of input texts.
    with torch.no_grad():
        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))

    result = []
    for ids, text in zip(outputs.logits, texts):
        # Decode the argmax token ids, strip special tokens and spaces,
        # and truncate back to the original text length.
        _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
        corrected_text = _text[:len(text)]
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text,
                       "new": corrected_text,
                       "re_pos": details})
    return result
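
# A minimal direct-call sketch, bypassing Flask (the input sentence is
# illustrative and the output depends on the loaded checkpoint):
#
#   main(["少先队员因该为老人让坐"])
#   # -> [{"old": "少先队员因该为老人让坐",
#   #      "new": "少先队员应该为老人让座",
#   #      "re_pos": [("因", "应", 4, 5), ("坐", "座", 10, 11)]}]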


@app.route("/predict", methods=["POST"])
def handle_query():
    logging.info("request from %s", request.remote_addr)
    texts = request.json["texts"]
    return_list = main(texts)
    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result
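
# A minimal client sketch (assumes the service is reachable on port 16000;
# illustrative only):
#
#   import requests
#   resp = requests.post("http://localhost:16000/predict",
#                        json={"texts": ["少先队员因该为老人让坐"]})
#   print(resp.json()["result"])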


if __name__ == "__main__":
    # Logging is already configured at module import time.
    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)