Layout recognition of heading levels vs. body text
import json
import os

# Disable Weights & Biases logging and pin the script to GPU 1
# (must be set before torch/transformers initialize CUDA).
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoConfig, BertTokenizer

from BertClsModel import BertForSequenceClassification
def load_model(model_path: str):
    """Load the fine-tuned BERT sentence classifier and its tokenizer from a local checkpoint."""
    config = AutoConfig.from_pretrained(
        model_path,
        num_labels=4,
    )
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = BertForSequenceClassification.from_pretrained(
        model_path,
        config=config,
    )
    return model, tokenizer


# Class id -> label name: 0 = body text, 1/2/3 = level-1/2/3 heading.
id_2_label = {
    0: "正文",
    1: "一级标题",
    2: "二级标题",
    3: "三级标题",
}
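
# --- Illustrative helper (not part of the original script) ---
# A minimal sketch of how the per-sentence inference in the loop below could be
# wrapped in a reusable function. The name predict_label, the device argument,
# and the default max_length are assumptions, not project definitions; it also
# assumes the custom BertForSequenceClassification returns logits as the first
# tuple element and that the model already lives on `device`.
def predict_label(text: str, model, tokenizer, device: str = "cuda", max_length: int = 2048) -> str:
    """Classify one sentence and return its label name from id_2_label."""
    encoded = tokenizer([text], max_length=max_length, truncation=True)
    input_ids = torch.tensor(encoded["input_ids"]).long().to(device)
    token_type_ids = torch.tensor(encoded["token_type_ids"]).long().to(device)
    attention_mask = input_ids.gt(0)
    with torch.no_grad():
        logits = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )[0]
    scores = torch.sigmoid(logits)[0].tolist()
    return id_2_label[max(range(len(scores)), key=lambda i: scores[i])]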
if __name__ == "__main__":
    model, tokenizer = load_model(model_path="/home/majiahui/project/text-classification-long/long_paper_1")
    model.to("cuda")
    model.eval()

    # Each CSV row is (text, gold_label); the gold label is not used for prediction here.
    data_list = pd.read_csv("data/long_paper_2.csv").values.tolist()
    data_new = []

    for row in tqdm(data_list):
        text = row[0]

        # Tokenize a single sentence; long inputs are truncated.
        encoded = tokenizer([text], max_length=2048, truncation=True)
        input_ids = torch.tensor(encoded["input_ids"]).long()
        token_type_ids = torch.tensor(encoded["token_type_ids"]).long()

        # Attention mask: 1 for real tokens, 0 for padding (token id 0).
        attention_mask = input_ids.gt(0).to("cuda")
        input_ids = input_ids.to("cuda")
        token_type_ids = token_type_ids.to("cuda")

        # Forward pass without labels; the first output element holds the logits.
        with torch.no_grad():
            result = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )

        # Take the class with the highest score as the predicted label.
        output = torch.sigmoid(result[0]).tolist()
        max_index = max(enumerate(output[0]), key=lambda x: x[1])[0]

        data_new.append(json.dumps({
            "text": text,
            "label": id_2_label[max_index],
        }, ensure_ascii=False))

    print(len(data_new))

    # Append one JSON object per line to the output file.
    with open("data/data_title_content.jsonl", "a", encoding="utf-8") as f:
        for line in data_new:
            f.write(line)
            f.write("\n")
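
# --- Usage note (illustrative, not from the original script) ---
# The output file holds one JSON object per line; a minimal sketch of reading it
# back, assuming the "data/data_title_content.jsonl" path written above:
#
#     with open("data/data_title_content.jsonl", encoding="utf-8") as f:
#         records = [json.loads(line) for line in f]
#     # each record looks like {"text": "...", "label": "正文"} (or a heading level)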