import json
import os

# Disable Weights & Biases logging
os.environ["WANDB_DISABLED"] = "true"

# Pin the script to GPU 1; this must be set before torch initializes CUDA
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoConfig, BertTokenizer

from BertClsModel import BertForSequenceClassification


def load_model(model_path: str):
    """Load the fine-tuned 4-class classifier and its tokenizer from a local checkpoint."""
    config = AutoConfig.from_pretrained(
        model_path,
        num_labels=4,
    )
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        config=config,
    )
    return model, tokenizer

# Map class ids to the Chinese label names written to the output file
id_2_label = {
    0: "正文",    # body text
    1: "一级标题",  # level-1 heading
    2: "二级标题",  # level-2 heading
    3: "三级标题",  # level-3 heading
}
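

# For reference, a minimal single-text inference sketch. It mirrors the batch
# loop below; classify_text is a hypothetical convenience helper added here,
# not part of the original pipeline, and it assumes the model has already been
# moved to `device`.
def classify_text(model, tokenizer, text: str, device: str = "cuda") -> str:
    encoded = tokenizer([text], max_length=2048, truncation=True)
    input_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
    token_type_ids = torch.tensor(encoded["token_type_ids"], dtype=torch.long)
    attention_mask = input_ids.gt(0)  # token id 0 is [PAD] for BERT vocabularies
    with torch.no_grad():
        logits = model(
            input_ids=input_ids.to(device),
            token_type_ids=token_type_ids.to(device),
            attention_mask=attention_mask.to(device),
        )[0]
    return id_2_label[int(logits[0].argmax())]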


if __name__ == "__main__":
    model, tokenizer = load_model(model_path='/home/majiahui/project/text-classification-long/long_paper_1')

    model.to("cuda")
    data_list = pd.read_csv("data/long_paper_2.csv").values.tolist()
    data_new = []

    zong = 0
    rel = 0
    jishu = 0

    with torch.no_grad():  # no gradients are needed for inference
        for row in tqdm(data_list):
            text = row[0]
            label = row[1]  # gold label from the CSV; kept for reference, not used below

            encoded = tokenizer([text], max_length=2048, truncation=True)

            # The tokenizer returns Python lists; convert them to long tensors
            input_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
            token_type_ids = torch.tensor(encoded["token_type_ids"], dtype=torch.long)

            # Token id 0 is [PAD] in BERT vocabularies, so non-zero ids mark real tokens
            attention_mask = input_ids.gt(0).to("cuda")
            input_ids = input_ids.to("cuda")
            token_type_ids = token_type_ids.to("cuda")

            # No labels are passed, so the model returns logits only
            result = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
            # Sigmoid scores follow the original script; the argmax below is the
            # same as it would be on raw logits or softmax probabilities
            output = torch.sigmoid(result[0]).tolist()

            # Take the index of the highest-scoring class as the prediction
            max_index = max(enumerate(output[0]), key=lambda x: x[1])[0]
            data_new.append(json.dumps({
                "text": text,
                "label": id_2_label[max_index],
            }, ensure_ascii=False))
    print(len(data_new))

    # Append one JSON record per line (JSON Lines format)
    with open("data/data_title_content.jsonl", "a", encoding="utf-8") as f:
        for record in data_new:
            f.write(record)
            f.write("\n")