Layout recognition: heading levels vs. body text
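"""Heading/body classification over three context views.

Each dev example wraps one target line in <Start>/<End> markers inside its
surrounding text. Three fine-tuned sequence classifiers score the target from
different views (centered context, preceding context only, following context
only), and a 2-of-3 majority vote over their predictions is compared against
the gold label to measure accuracy.
"""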
import json
import os

# os.environ["WANDB_DISABLED"] = "true"
# Pin the job to a specific CUDA device (must be set before CUDA is initialized)
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import pandas as pd
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Check whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Label set: "标题" (heading), "正文" (body text), "无用类别" (useless/other).
# The strings are kept in Chinese because they are the raw label values.
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2,
}
# Alternative two-class scheme:
# lable_2_id_fenji = {
#     "标题+正文": 0,
#     "无用类别": 1,
# }
# All three checkpoints share the same tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_shout_3",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

def load_classifier(model_name):
    """Load a fine-tuned sequence classifier and move it to the target device."""
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=len(lable_2_id_fenji),
        revision="main",
        trust_remote_code=False,
    )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)

# One model per context view. Judging by the checkpoint names, "no_start" was
# fine-tuned without the preceding context and "no_end" without the following
# context; the base model sees context on both sides.
model_roberta = load_classifier("data_zong_shout_3")
model_roberta_no_start = load_classifier("data_zong_no_start_shout_3")
model_roberta_no_end = load_classifier("data_zong_no_end_shout_3")
# Invert the label mapping: class id -> label string
id_2_lable = {v: k for k, v in lable_2_id_fenji.items()}
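# Resulting mapping for the three-class scheme: {0: "标题", 1: "正文", 2: "无用类别"}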
# Evaluate on the dev split only; the training split can be appended if needed.
data_zong_1 = pd.read_csv("data/data_zong_dev_2.csv").values.tolist()
# data_zong_2 = pd.read_csv("data/data_zong_train.csv").values.tolist()
# data_zong = data_zong_1 + data_zong_2
data_zong = data_zong_1
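# Assumed row format (inferred from the <Start>/<End> splitting below; the
# original data files are not shown here): each row is [sentence, label],
# where `sentence` is a block of newline-separated lines with exactly one
# target line wrapped in markers, e.g.
#
#   context line above
#   <Start>target line<End>
#   context line below
#
# and `label` is the integer class id of the target line.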
output_text = ""        # per-example report for every prediction
output_text_false = ""  # report restricted to misclassified examples
zong_len = 0            # total number of evaluated examples
zhengque = 0            # number of correct majority-vote predictions
tongji_false = {}       # error counts per (true, predicted) pair; only filled by the commented block below
data_zong_true = []     # correctly classified [sentence, label] rows
data_zong_false = []    # misclassified sentences
for i in tqdm(range(len(data_zong))):
    zong_len += 1
    paper = data_zong[i][0]
    lable_true = data_zong[i][1]
    # Split out the context before <Start> and after <End>; the target line
    # itself sits between the two markers.
    paper_new_start_1, paper_new_start_2 = paper.split("<Start>")
    paper_new_end_1, paper_new_end_2 = paper.split("<End>")
    paper_object = "<Start>" + paper_new_start_2.split("\n")[0]
    # View 1: 7 context sentences before and after the target
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-7:])
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:7])
    paper_zhong = "\n".join([paper_new_start_1_jiequ, paper_object, paper_new_end_2_jiequ])
    # View 2: 15 context sentences before the target only
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-15:])
    paper_qian = "\n".join([paper_new_start_1_jiequ, paper_object])
    # View 3: 15 context sentences after the target only
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:15])
    paper_hou = "\n".join([paper_object, paper_new_end_2_jiequ])

    # Score view 1 (target in the middle) with the full-context model
    result = tokenizer([paper_zhong], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta(**result_on_device)
    predicted_class_idx_zhong = torch.argmax(outputs.logits, dim=1).item()

    # Score view 2 (preceding context only) with the "no_end" model
    result = tokenizer([paper_qian], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta_no_end(**result_on_device)
    predicted_class_idx_qian = torch.argmax(outputs.logits, dim=1).item()

    # Score view 3 (following context only) with the "no_start" model
    result = tokenizer([paper_hou], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta_no_start(**result_on_device)
    predicted_class_idx_hou = torch.argmax(outputs.logits, dim=1).item()
    output_text += paper
    output_text += "\n"
    output_text += f"Prediction (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
    output_text += f"Prediction (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
    output_text += f"Prediction (centered context): {id_2_lable[predicted_class_idx_zhong]}\n"
    output_text += f"True label: {id_2_lable[lable_true]}\n"
    output_text += "\n==============================================================================\n"

    # Majority vote over the three views: a prediction counts only when at
    # least two of the three models agree on the same class.
    id_2_len = {}
    for pred in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
        id_2_len[pred] = id_2_len.get(pred, 0) + 1
    queding = False
    predicted_class_idx = ""
    for pred, count in id_2_len.items():
        if count >= 2:
            queding = True
            predicted_class_idx = pred
    if queding and predicted_class_idx == lable_true:
        zhengque += 1
        data_zong_true.append([paper, lable_true])
    else:
        # jian = "-true-" + id_2_lable[lable_true] + "-false-" + id_2_lable[predicted_class_idx]
        # if jian not in tongji_false:
        #     tongji_false[jian] = 1
        # else:
        #     tongji_false[jian] += 1
        output_text_false += paper
        output_text_false += "\n"
        output_text_false += f"Prediction (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
        output_text_false += f"Prediction (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
        output_text_false += f"Prediction (centered context): {id_2_lable[predicted_class_idx_zhong]}\n"
        output_text_false += f"True label: {id_2_lable[lable_true]}\n"
        output_text_false += "\n==============================================================================\n"
        data_zong_false.append([paper])
# Target JSONL format for the optional dump of misclassified examples:
'''
{"text": "Terrible customer service.", "label": ["negative"]}
{"text": "Really great transaction.", "label": ["positive"]}
{"text": "Great price.", "label": ["positive"]}
'''
# with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f:
#     for item in data_zong_false:
#         f.write(json.dumps({"text": item[0], "label": []}, ensure_ascii=False))
#         f.write("\n")
# pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv(
#     "data/data_zong_true.csv", index=False, encoding="utf-8")

# Write the full report and the misclassified-only report
with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f:
    f.write(output_text)
with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f:
    f.write(output_text_false)

# Majority-vote accuracy on the dev set; tongji_false stays empty unless the
# statistics block inside the loop is re-enabled.
print(zhengque / zong_len)
print(tongji_false)