You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.0 KiB
155 lines
4.0 KiB
import json
|
|
import os
|
|
import re
|
|
|
|
os.environ["WANDB_DISABLED"] = "true"
|
|
|
|
|
|
# 设置CUDA设备
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
|
|
|
|
|
import logging
|
|
import os
|
|
import random
|
|
import sys
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
import datasets
|
|
import evaluate
|
|
import numpy as np
|
|
from datasets import load_dataset
|
|
import torch
|
|
import transformers
|
|
from transformers import (
|
|
AutoConfig,
|
|
AutoModelForSequenceClassification,
|
|
AutoTokenizer,
|
|
DataCollatorWithPadding,
|
|
EvalPrediction,
|
|
HfArgumentParser,
|
|
PretrainedConfig,
|
|
Trainer,
|
|
TrainingArguments,
|
|
default_data_collator,
|
|
set_seed,
|
|
BertTokenizer,
|
|
BertModel
|
|
)
|
|
from transformers import BigBirdModel
|
|
from transformers.trainer_utils import get_last_checkpoint
|
|
from transformers.utils import check_min_version, send_example_telemetry
|
|
from transformers.utils.versions import require_version
|
|
|
|
from BertClsModel import BertForSequenceClassification
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
|
|
|
|
def load_model(model_path: str):
|
|
config = AutoConfig.from_pretrained(
|
|
model_path,
|
|
num_labels=4,
|
|
)
|
|
|
|
tokenizer = BertTokenizer.from_pretrained(
|
|
model_path
|
|
)
|
|
|
|
model = BertForSequenceClassification.from_pretrained(
|
|
model_path,
|
|
config=config
|
|
)
|
|
return model, tokenizer
|
|
|
|
id_2_lable = {
|
|
0: "正文",
|
|
1: "一级标题",
|
|
2: "二级标题",
|
|
3: "三级标题",
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
model, tokenizer = load_model(model_path='/home/majiahui/project/text-classification-long/long_paper_1')
|
|
|
|
# text = "(1)经病理学或细胞学确诊的肺癌患者;"
|
|
#
|
|
# sen = [text]
|
|
# result = tokenizer(sen, max_length=512, truncation=True)
|
|
# print(result)
|
|
#
|
|
# input_ids = result['input_ids']
|
|
# token_type_ids = result['token_type_ids']
|
|
#
|
|
# input_ids = seq_padding(tokenizer, input_ids)
|
|
# token_type_ids = seq_padding(tokenizer, token_type_ids)
|
|
#
|
|
#
|
|
# result = model(input_ids=input_ids,token_type_ids=token_type_ids) # 这里不需要labels
|
|
# output = torch.sigmoid(result[0][0]).tolist()
|
|
# # result_ = result[0][0]
|
|
# print(output)
|
|
|
|
model.to("cuda")
|
|
data_list = pd.read_csv("data/long_paper_2.csv").values.tolist()
|
|
data_new = []
|
|
|
|
zong = 0
|
|
rel = 0
|
|
jishu = 0
|
|
|
|
for i in tqdm(data_list):
|
|
# print(zong)
|
|
# print(i)
|
|
zong += 1
|
|
text = i[0]
|
|
lable = i[1]
|
|
|
|
result = tokenizer([text], max_length=2048, truncation=True)
|
|
|
|
input_ids = result['input_ids']
|
|
token_type_ids = result['token_type_ids']
|
|
# print(input_ids)
|
|
# print(text)
|
|
# print(lable)
|
|
input_ids = torch.tensor(input_ids) # 将列表转换为 PyTorch tensor
|
|
token_type_ids = torch.tensor(token_type_ids) # 将列表转换为 PyTorch tensor
|
|
input_ids = input_ids.long()
|
|
token_type_ids = token_type_ids.long()
|
|
|
|
batch_masks = input_ids.gt(0).to("cuda")
|
|
input_ids, token_type_ids = input_ids.to("cuda"), token_type_ids.to("cuda")
|
|
result = model(input_ids=input_ids,token_type_ids=token_type_ids, attention_mask=batch_masks) # 这里不需要labels
|
|
# output = torch.sigmoid(result[0][0]).tolist()
|
|
# # result_ = result[0][0]
|
|
# if output[1] > 0.5:
|
|
# rel += 1
|
|
#
|
|
# data_new.append({
|
|
# "index": index,
|
|
# "text": text,
|
|
# "acc": output,
|
|
# })
|
|
|
|
output = torch.sigmoid(result[0]).tolist()
|
|
# print(output)
|
|
|
|
# if output[0][0] > 0.60:
|
|
# predict_lable = 0
|
|
# else:
|
|
# predict_lable = 1
|
|
|
|
max_index = max(enumerate(output[0]), key=lambda x: x[1])[0]
|
|
# print(max_index) # 输出最大值的下标
|
|
jishu +=1
|
|
data_new.append(json.dumps({
|
|
"text": text,
|
|
"label": id_2_lable[max_index],
|
|
}, ensure_ascii=False))
|
|
print(len(data_new))
|
|
with open("data/data_title_content.jsonl", "a", encoding="utf-8") as f:
|
|
for i in data_new:
|
|
f.write(i)
|
|
f.write("\n")
|
|
|