# -*- coding: utf-8 -*-

"""
@Time    :  2023/1/31 19:02
@Author  :
@FileName:
@Software:
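@Describe: Filter tab-separated sentence pairs (digit check, whitespace check, character-level similarity, BERT [CLS] cosine similarity) and write the kept pairs to a new training file.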
"""
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda, Dense
import tensorflow as tf
from keras.backend import set_session
from sklearn.metrics.pairwise import cosine_similarity
from rouge import Rouge  # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from tqdm import tqdm
import jieba
from gensim.models import KeyedVectors, word2vec, Word2Vec
import random
import difflib
import re

# Configure the (TF1-style API) session used by Keras: allow_growth lets
# TensorFlow grab GPU memory incrementally instead of reserving it all at start.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config)) # register as the Keras backend session (original note: "differs here")

class Word2vecModel:
    """Thin wrapper around a pre-trained gensim ``Word2Vec`` model.

    The model is loaded from disk on construction; :meth:`word2vec_res` maps
    two token lists to their word vectors.
    """

    # Raw string: the previous literal only worked because invalid escape
    # sequences such as "\p" and "\w" were left untouched by the parser,
    # which emits a SyntaxWarning on Python 3.12+.
    DEFAULT_PATH = r"E:\pycharm_workspace\查重分析\word2vec_model\word2vec_add_new_18.model"

    def __init__(self, path=DEFAULT_PATH):
        """Load the model saved at *path* (defaults to :attr:`DEFAULT_PATH`)."""
        self.path = path
        self.model = Word2Vec.load(self.path)

    def word2vec_res(self, seg_0_list, seg_1_list):
        """Look up the vector of every token in two token lists.

        Args:
            seg_0_list: tokens of the first sentence.
            seg_1_list: tokens of the second sentence.

        Returns:
            A tuple ``(vectors_0, vectors_1)`` of lists, one vector per token,
            in input order.

        Raises:
            KeyError: presumably, for an out-of-vocabulary token (gensim's
                ``KeyedVectors`` lookup behavior) — no fallback is applied.
        """
        vectors = self.model.wv
        return [vectors[w] for w in seg_0_list], [vectors[w] for w in seg_1_list]

class Evaluator(keras.callbacks.Callback):
    """Score one sentence against another with ROUGE-1/2/L F-scores and BLEU.

    Derives from the Keras ``Callback`` so it can be attached to training, but
    no epoch hooks are implemented; only :meth:`evaluate_t` is provided.
    """

    def __init__(self):
        # Callback.__init__ sets up the attributes Keras relies on (e.g. the
        # ``model`` slot); skipping it breaks use as a real callback.
        super().__init__()
        self.rouge = Rouge()
        self.smooth = SmoothingFunction().method1
        self.best_bleu = 0.

    def evaluate_t(self, data_1, data_2, topk=1):
        """Return ``[rouge_1_f, rouge_2_f, rouge_l_f, bleu]`` for one pair.

        Args:
            data_1: first text; tokens are obtained by ``split(' ')`` for BLEU,
                so space-separated tokens are presumably expected.
            data_2: second text, same format.
            topk: unused; kept for backward compatibility.

        NOTE(review): ROUGE treats *data_1* as the hypothesis and *data_2* as
        the reference, while BLEU uses the opposite roles. ROUGE-L's F value
        is not symmetric, so confirm which role assignment is intended.
        """
        scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2])[0]
        bleu = sentence_bleu(
            references=[data_1.split(' ')],
            hypothesis=data_2.split(' '),
            smoothing_function=self.smooth,
        )
        return [
            scores['rouge-1']['f'],
            scores['rouge-2']['f'],
            scores['rouge-l']['f'],
            bleu,
        ]

class bertModel:
    """BERT sentence encoder returning the ``[CLS]`` (first-position) vector per text.

    Loads a BERT checkpoint through bert4keras with a simplified vocabulary
    and exposes :meth:`predict` (one text) and :meth:`predict_batch` (many).
    """

    # Checkpoint directory used when none is given. Other checkpoints tried
    # previously: chinese_simbert_L-12_H-768_A-12 and
    # chinese_roberta_wwm_ext_L-12_H-768_A-12.
    DEFAULT_MODEL_PATH = "/home/majiahui/project/models-llm/keras/chinese_L-12_H-768_A-12"

    def __init__(self, modelpath=DEFAULT_MODEL_PATH, maxlen=256):
        """
        Args:
            modelpath: directory containing ``bert_config.json``,
                ``bert_model.ckpt`` and ``vocab.txt``.
            maxlen: value passed as ``maxlen`` to ``Tokenizer.encode``.
        """
        self.maxlen = maxlen
        self.config_path = os.path.join(modelpath, 'bert_config.json')
        self.checkpoint_path = os.path.join(modelpath, 'bert_model.ckpt')
        self.dict_path = os.path.join(modelpath, 'vocab.txt')
        # simplified=True renumbers the vocabulary; keep_tokens lists which
        # rows of the checkpoint's embedding matrix correspond to it
        # (per bert4keras docs).
        self.token_dict, self.keep_tokens = load_vocab(
            dict_path=self.dict_path,
            simplified=True,
            startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
        )
        self.tokenizer = Tokenizer(self.token_dict, do_lower_case=True)
        self.buildmodel()

    def buildmodel(self):
        """Build ``self.model``: BERT followed by a slice of position 0 ([CLS])."""
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            # Must match the simplified vocabulary used by the tokenizer;
            # without it, token ids index the wrong embedding rows.
            keep_tokens=self.keep_tokens,
            return_keras_model=False,
        )

        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        self.model = keras.models.Model(bert.model.input, output)
        self.model.summary()

    def predict(self, text):
        """Encode a single text; returns a batch with one row."""
        return self.predict_batch([text])

    def predict_batch(self, text_list):
        """Encode several texts in one forward pass; returns one row per text."""
        batch_token_ids, batch_segment_ids = [], []
        for t in text_list:
            token_ids, segment_ids = self.tokenizer.encode(t, maxlen=self.maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)

        # Pad to the longest sequence in the batch so inputs form a rectangle.
        batch_token_ids = sequence_padding(batch_token_ids)
        batch_segment_ids = sequence_padding(batch_segment_ids)
        return self.model.predict([batch_token_ids, batch_segment_ids])

def simbert(data_1, data_2):
    """Unimplemented stub (the name suggests a SimBERT similarity check); returns None."""

def word2vec():
    """Unimplemented stub (presumably for a word2vec similarity check); returns None."""

def bleu():
    """Unimplemented stub (presumably for a BLEU similarity check); returns None."""

def bool_len_strsim(data_1, data_2, *, threshold=0.65, min_len_ratio=0.8):
    """Flag a sentence pair whose character-level similarity is too low.

    The similarity is ``difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()``.
    When *data_2* is shorter than *data_1*:

    * if ``len(data_2) / len(data_1) <= min_len_ratio`` the pair is never flagged;
    * otherwise the score is rescaled to ``1 - sim * (1 - len_ratio)`` before the
      threshold comparison.

    Args:
        data_1: first sentence.
        data_2: second sentence.
        threshold: scores strictly below this flag the pair (default 0.65).
        min_len_ratio: length-ratio cut-off for a shorter *data_2* (default 0.8).

    Returns:
        ``(True, score)`` when the (possibly rescaled) score is below
        *threshold*, otherwise ``(False, "")``.
    """
    sim = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
    len_1, len_2 = len(data_1), len(data_2)
    if len_2 < len_1:  # len_1 >= 1 here, so the division is safe
        len_ratio = len_2 / len_1
        if len_ratio <= min_len_ratio:
            return False, ""
        # NOTE(review): with the default min_len_ratio of 0.8, (1 - len_ratio) < 0.2,
        # so the rescaled score is always > 0.8 and a shorter data_2 can never be
        # flagged at threshold 0.65. Perhaps ``sim * len_ratio`` was intended —
        # confirm before changing.
        sim = 1 - sim * (1 - len_ratio)

    if sim < threshold:
        return True, sim
    return False, ""


def has_numbers(input_string):
    """Return True if at least one character of *input_string* satisfies ``str.isdigit``."""
    return any(map(str.isdigit, input_string))


def bool_num(data_1, data_2):
    """Return True only when both strings contain at least one digit."""
    # has_numbers returns a bool, so the ``and`` result is always True/False.
    return has_numbers(data_1) and has_numbers(data_2)

def is_contains_english(str):
    """Return True if *str* contains at least one ASCII Latin letter (A-Z or a-z)."""
    # Only the presence of a match matters, so stop at the first one.
    return re.search(r'[A-Za-z]', str, re.S) is not None


def is_contains_kongge(str):
    """Return True if *str* contains an ASCII space or a tab ("kongge" = space)."""
    return any(sep in str for sep in (" ", "\t"))

def _read_lines(path):
    """Return the stripped, non-empty lines of *path*.

    The file is decoded as UTF-8 first and re-read as GBK if that fails.
    """
    try:
        with open(path, 'r', encoding="utf-8") as f:
            return [x.strip() for x in f if x.strip() != '']
    except UnicodeDecodeError:
        # Only a decoding failure should trigger the GBK retry; a bare
        # ``except:`` would also swallow KeyboardInterrupt and unrelated errors.
        with open(path, 'r', encoding="gbk") as f:
            return [x.strip() for x in f if x.strip() != '']


def _main(in_file="../data/train_yy_pre.txt",
          out_file='../data/train_new/train_yy_1.txt',
          cos_threshold=0.9):
    """Filter sentence pairs from *in_file* and write the kept ones to *out_file*.

    Each input line is split on tabs; lines without exactly three fields are
    skipped, and fields 0 and 2 are taken as the pair (data_1, data_2).
    A pair is kept only if:

    * both sentences contain a digit (``bool_num``);
    * data_1 contains no space or tab (``is_contains_kongge``);
    * ``bool_len_strsim`` does not flag it as too dissimilar;
    * the cosine similarity of the two BERT [CLS] vectors is > *cos_threshold*.

    Kept pairs are written one per line as ``data_1<TAB>to<TAB>data_2``.
    """
    model = bertModel()
    lines = _read_lines(in_file)
    print(len(lines))

    data_new = []
    for txt in tqdm(lines):
        fields = txt.split('\t')
        if len(fields) != 3:
            continue
        data_1, data_2 = fields[0], fields[2]

        # Both sentences must contain at least one digit.
        if not bool_num(data_1, data_2):
            continue
        # Skip pairs whose first sentence contains a space or tab.
        if is_contains_kongge(data_1):
            continue
        # Skip pairs flagged as too dissimilar at the character level.
        too_dissimilar, _ = bool_len_strsim(data_1, data_2)
        if too_dissimilar:
            continue

        # Semantic check: cosine similarity of the two [CLS] embeddings,
        # encoded together in one forward pass.
        y = model.predict_batch([data_1, data_2])
        cos_sim = cosine_similarity(y[0].reshape(1, -1), y[1].reshape(1, -1))[0][0]
        # ``not >`` (rather than ``<=``) so a NaN score is rejected too.
        if not cos_sim > cos_threshold:
            continue

        data_new.append("\t".join([data_1, "to", data_2]))

    out_dir = os.path.dirname(out_file)
    if out_dir:
        # Create the directory up front so a long inference run cannot fail
        # only at the final write.
        os.makedirs(out_dir, exist_ok=True)
    with open(out_file, 'w', encoding='utf-8') as f:
        for line in data_new:
            f.write(line + '\n')


if __name__ == '__main__':
    _main()