# -*- coding: utf-8 -*-
# @Time:  16:41
# @Author:ZYP
# @File:LoadRoformer.py
# @mail:zypsunshine1@gmail.com
# @Software: PyCharm

# =========================================================================================
# 加载深度学习模型
# · 加载论文分类模型
# · 加载 BERT 模型
# =========================================================================================

import json
import numpy as np
from bert4keras.models import build_transformer_model
from keras.layers import Lambda, Dense
from keras.models import Model
from bert4keras.tokenizers import Tokenizer


def load_roformer_model(config, ckpt, model_weight_path, num_classes=None):
    """Build the RoFormer-v2 multi-label classifier and load its trained weights.

    The backbone output at position 0 of axis 1 (presumably the [CLS] token)
    is fed into a linear ``Dense`` layer with one unit per class.

    :param config: path to the RoFormer-v2 config file.
    :param ckpt: path to the pretrained RoFormer-v2 checkpoint.
    :param model_weight_path: path to the fine-tuned classifier weights.
    :param num_classes: number of output units; when omitted, the module-level
        ``class_nums`` setting is used.
    :return: the Keras ``Model`` with the weights loaded.
    """
    if num_classes is None:
        # Look the global up at call time: in this module it is assigned
        # further down, after this function is defined.
        num_classes = class_nums
    roformer = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='roformer_v2',
        return_keras_model=False)
    output = Lambda(lambda x: x[:, 0])(roformer.model.output)
    # No activation: the layer emits raw linear scores, one per class.
    output = Dense(
        units=num_classes,
        kernel_initializer=roformer.initializer
    )(output)
    model1 = Model(roformer.model.input, output)
    # The layer shapes built above (including num_classes) must match the saved weights.
    model1.load_weights(model_weight_path)
    model1.summary()
    return model1


def load_label(label_path1):
    """Load label/id mappings and per-class decision thresholds from a JSON file.

    The file must contain a JSON object mapping each label name to a sequence
    ``[class_id, threshold]``.

    :param label_path1: path to the UTF-8 encoded JSON label file.
    :return: ``(label2id, id2label, label_threshold)`` where ``label2id`` maps
        label name -> class id, ``id2label`` maps class id -> label name, and
        ``label_threshold`` is a 1-D numpy array of thresholds ordered by
        ascending class id (so entry ``i`` belongs to class ``i`` when ids are
        ``0..N-1``).
    """
    with open(label_path1, 'r', encoding='utf-8') as f:
        json_dict = json.load(f)
    label2id1 = {label: value[0] for label, value in json_dict.items()}
    id2label1 = {value[0]: label for label, value in json_dict.items()}
    # Order thresholds by class id, not by the file's key order: callers compare
    # this array element-wise against the model's output columns, so it must be
    # aligned by id even if the JSON keys are stored in a different order.
    ordered_values = sorted(json_dict.values(), key=lambda value: value[0])
    label_threshold1 = np.array([value[1] for value in ordered_values])
    return label2id1, id2label1, label_threshold1


def encode(text_list1):
    """Concatenate several texts into a single RoFormer input sequence.

    Each text is tokenized with ``tokenizer_roformer``. Every text after the
    first has its leading token removed (presumably the [CLS] token), so only
    the first text keeps it. Texts at odd positions (2nd, 4th, ...) get segment
    id 1 for all their tokens; the others keep the tokenizer's segment ids. The
    joined sequence is truncated to ``max_len`` tokens.

    :param text_list1: list of strings, encoded in order.
    :return: ``(token_ids, segment_ids)``, each a numpy array of shape ``(1, L)``.
    """
    all_tokens, all_segments = [], []
    for position, text in enumerate(text_list1):
        tokens, segments = tokenizer_roformer.encode(text)
        if position > 0:
            tokens, segments = tokens[1:], segments[1:]
        if position % 2 == 1:
            # Alternate segments so neighbouring texts are distinguishable.
            segments = [1] * len(tokens)
        all_tokens.extend(tokens)
        all_segments.extend(segments)

    if len(all_tokens) > max_len:
        all_tokens, all_segments = all_tokens[:max_len], all_segments[:max_len]

    return np.array([all_tokens]), np.array([all_segments])


def load_bert_model(config, ckpt, model_weight_path):
    """Build the BERT sentence-vector model and load its weights.

    The model returns the backbone output at position 0 of axis 1 (presumably
    the [CLS] token vector); no extra layers are added on top.

    :param config: path to the BERT config file.
    :param ckpt: path to the pretrained BERT checkpoint.
    :param model_weight_path: path to the weights to load into the model.
    :return: the Keras ``Model`` with the weights loaded.
    """
    backbone = build_transformer_model(
        config_path=config,
        checkpoint_path=ckpt,
        model='bert',
        return_keras_model=False,
    )
    first_position = Lambda(lambda hidden: hidden[:, 0])(backbone.model.output)

    sentence_encoder = Model(inputs=backbone.model.input, outputs=first_position)
    sentence_encoder.load_weights(model_weight_path)
    sentence_encoder.summary()
    return sentence_encoder


def return_sent_vec(sent_list):
    """
    Turn every sentence into a vector with the BERT model.

    Sentences are processed one at a time: each is encoded with
    ``tokenizer_bert`` (``maxlen=512``), passed to ``bert_model.predict`` as a
    batch of one, and the first row of the prediction is kept as a plain list.

    :param sent_list: list of sentences.
    :return: two values: the input ``sent_list`` itself, and a list with one
             vector per sentence, in the same order.
    """
    def _sentence_to_vector(sentence):
        token_ids, segment_ids = tokenizer_bert.encode(sentence, maxlen=512)
        prediction = bert_model.predict([np.array([token_ids]), np.array([segment_ids])])
        return prediction[0].tolist()

    vectors = [_sentence_to_vector(sentence) for sentence in sent_list]
    return sent_list, vectors


def pred_class_num(target_paper_dict):
    """Predict the class indices of a paper with the RoFormer classifier.

    The model input is built from the paper's ``title``, ``key_words`` and
    ``abst_zh`` fields. If the abstract splits into more than 10 pieces on
    "。", only the first 5 and the last 5 pieces are used, as two separate
    texts; otherwise the whole abstract is used as one text.

    :param target_paper_dict: dict with at least the keys ``title``,
        ``key_words`` and ``abst_zh``.
    :return: list of indices of the output scores that exceed the matching
        entry of ``label_threshold``.
    """
    text_list1 = [target_paper_dict['title'], target_paper_dict['key_words']]
    abst_zh = target_paper_dict['abst_zh']
    pieces = abst_zh.split("。")

    if len(pieces) > 10:
        # Long abstract: keep only its head and tail to bound the input length.
        text_list1.extend(["。".join(pieces[:5]), "。".join(pieces[-5:])])
    else:
        text_list1.append(abst_zh)

    sent_token, segment_ids = encode(text_list1)
    scores = model_roformer.predict([sent_token, segment_ids])[0]
    above_threshold = scores > label_threshold
    return [position for position, hit in enumerate(above_threshold) if hit]


# =========================================================================================================================
# RoFormer model settings
# NOTE: everything below runs at import time — importing this module builds the
# tokenizer, reads the label file and loads the classifier.
# =========================================================================================================================
# Number of output units of the classifier's Dense layer (read by load_roformer_model);
# presumably must match the saved classifier weights.
class_nums = 168
# Maximum token length of the concatenated classifier input (read by encode).
max_len = 1500
# NOTE(review): the paths below are empty placeholders and presumably must be
# filled in before this module is imported.
roformer_config_path = ''
roformer_ckpt_path = ''
roformer_vocab_path = ''
roformer_model_weights_path = ''
# JSON file of {label: [class_id, threshold]} read by load_label; relative
# paths are resolved against the current working directory.
label_path = '../data/label_threshold.txt'

# Tokenizer for the RoFormer model
tokenizer_roformer = Tokenizer(roformer_vocab_path)
# Load label mappings and per-class thresholds
label2id, id2label, label_threshold = load_label(label_path)
# Load the trained classifier (must come after class_nums is defined)
model_roformer = load_roformer_model(roformer_config_path, roformer_ckpt_path, roformer_model_weights_path)

# =========================================================================================================================
# BERT model settings
# NOTE: like the RoFormer section, this runs at import time and loads the model.
# =========================================================================================================================
# NOTE(review): the paths below are empty placeholders and presumably must be
# filled in before this module is imported.
bert_config_path = ''
bert_ckpt_path = ''
bert_vocab_path = ''
bert_model_weight_path = ''

# Tokenizer for the BERT model (used by return_sent_vec)
tokenizer_bert = Tokenizer(bert_vocab_path)
# Load the BERT model used to extract sentence vectors
bert_model = load_bert_model(bert_config_path, bert_ckpt_path, bert_model_weight_path)