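"""Ensemble evaluation of three fine-tuned sentence classifiers.

Each example wraps one target sentence in <Start>/<End> markers inside its
surrounding document. Three models, each trained on a different context
view (centered, preceding-only, following-only), classify the sentence,
and the final label is a majority vote over the three predictions.
Per-example logs are written under data/ and overall accuracy is printed.
"""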
import json
import os

# os.environ["WANDB_DISABLED"] = "true"

# Pin this process to GPU 2
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import pandas as pd
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Check whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Three-way label scheme: "标题" = heading, "正文" = body text,
# "无用类别" = useless/other
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}

# Alternative binary scheme (commented out):
# lable_2_id_fenji = {
#     "标题+正文": 0,  # heading + body text
#     "无用类别": 1    # useless/other
# }
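
# The tokenizer and the three classifiers below are loaded from local
# fine-tuned checkpoint directories (RoBERTa-based, judging by the model
# variable names); each checkpoint apparently targets a different context
# view of the sentence being classified.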
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_shout_3",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

# Classifier for the centered view (context before and after the target)
model_name = "data_zong_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
model_roberta = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)

# Classifier trained without preceding context (used on the target plus
# the following sentences)
model_name = "data_zong_no_start_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
model_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)

# Classifier trained without following context (used on the preceding
# sentences plus the target)
model_name = "data_zong_no_end_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
model_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)

# lable_2_id_fenji = {
#     "标题": 0,
#     "正文": 1,
#     "无用类别": 2
# }

# Invert the label map: id -> label name
id_2_lable = {}
for i in lable_2_id_fenji:
    if lable_2_id_fenji[i] not in id_2_lable:
        id_2_lable[lable_2_id_fenji[i]] = i

# Load the dev split for evaluation
data_zong_1 = pd.read_csv("data/data_zong_dev_2.csv").values.tolist()
# data_zong_2 = pd.read_csv("data/data_zong_train.csv").values.tolist()
# data_zong = data_zong_1 + data_zong_2
data_zong = data_zong_1

output_text = ""        # per-example log for all predictions
output_text_false = ""  # per-example log for misclassified examples only
zong_len = 0            # total number of evaluated examples
zhengque = 0            # number of correct ensemble predictions

tongji_false = {}  # error-pair statistics (only filled by the commented-out block in the else branch below)

data_zong_true = []
data_zong_false = []
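
# For each example: locate the target sentence via its <Start>/<End>
# markers, build three context views (centered, preceding-only,
# following-only), classify each view with its matching model, then take
# a majority vote over the three per-view predictions.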
for i in tqdm(range(len(data_zong))):
    zong_len += 1
    paper = data_zong[i][0]
    lable_true = data_zong[i][1]
    # The target sentence is wrapped in <Start> / <End> markers
    paper_new_start_1, paper_new_start_2 = paper.split("<Start>")
    paper_new_end_1, paper_new_end_2 = paper.split("<End>")

    # Centered view: 7 sentences before and 7 after the target
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-7:])
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:7])
    paper_object = "<Start>" + paper_new_start_2.split("\n")[0]
    paper_zhong = "\n".join([paper_new_start_1_jiequ, paper_object, paper_new_end_2_jiequ])

    # Preceding-only view: the 15 sentences before the target
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-15:])
    paper_qian = "\n".join([paper_new_start_1_jiequ, paper_object])

    # Following-only view: the 15 sentences after the target
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:15])
    paper_hou = "\n".join([paper_object, paper_new_end_2_jiequ])

    # Prediction for the centered view
    sentence_list = [paper_zhong]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta(**result_on_device)
    predicted_class_idx_zhong = torch.argmax(outputs.logits, dim=1).item()

    # Prediction for the preceding-only view
    sentence_list = [paper_qian]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta_no_end(**result_on_device)
    predicted_class_idx_qian = torch.argmax(outputs.logits, dim=1).item()

    # Prediction for the following-only view
    sentence_list = [paper_hou]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        outputs = model_roberta_no_start(**result_on_device)
    predicted_class_idx_hou = torch.argmax(outputs.logits, dim=1).item()
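
    # Log the raw document, all three per-view predictions, and the gold
    # label so disagreements can be inspected by hand.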
    output_text += paper
    output_text += "\n"
    output_text += f"Predicted class (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
    output_text += f"Predicted class (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
    output_text += f"Predicted class (centered view): {id_2_lable[predicted_class_idx_zhong]}\n"
    output_text += f"True class: {id_2_lable[lable_true]}\n"
    output_text += "\n==============================================================================\n"

    # Count how many views voted for each class
    id_2_len = {}
    for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
        if i not in id_2_len:
            id_2_len[i] = 1
        else:
            id_2_len[i] += 1

    # The ensemble prediction is accepted only when at least two of the
    # three views agree
    queding = False
    predicted_class_idx = ""
    for i in id_2_len:
        if id_2_len[i] >= 2:
            queding = True
            predicted_class_idx = i

    if queding and predicted_class_idx == lable_true:
        zhengque += 1
        data_zong_true.append([paper, lable_true])
    else:
        # jian = "-true-" + id_2_lable[lable_true] + "-false-" + id_2_lable[predicted_class_idx]
        # if jian not in tongji_false:
        #     tongji_false[jian] = 1
        # else:
        #     tongji_false[jian] += 1
        output_text_false += paper
        output_text_false += "\n"
        output_text_false += f"Predicted class (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
        output_text_false += f"Predicted class (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
        output_text_false += f"Predicted class (centered view): {id_2_lable[predicted_class_idx_zhong]}\n"
        output_text_false += f"True class: {id_2_lable[lable_true]}\n"
        output_text_false += "\n==============================================================================\n"
        data_zong_false.append([paper])
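
# Note: the hand-rolled vote in the loop is equivalent to a Counter-based
# majority vote, e.g. (sketch only, not used above):
#   from collections import Counter
#   cls_idx, votes = Counter([predicted_class_idx_qian, predicted_class_idx_hou,
#                             predicted_class_idx_zhong]).most_common(1)[0]
#   queding = votes >= 2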

# Sample JSONL format; the commented-out block below writes the
# misclassified examples in this format (with empty labels), presumably
# for re-annotation.
'''
{"text": "Terrible customer service.", "label": ["negative"]}
{"text": "Really great transaction.", "label": ["positive"]}
{"text": "Great price.", "label": ["positive"]}
'''
# with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f:
#     for i in data_zong_false:
#         f.write(json.dumps({
#             "text": i[0],
#             "label": []
#         }, ensure_ascii=False))
#         f.write("\n")
#
# pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv("data/data_zong_true.csv", index=False, encoding="utf-8")

with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f:
    f.write(output_text)

with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f:
    f.write(output_text_false)

# Overall ensemble accuracy on the dev set
print(zhengque / zong_len)
# Error-pair statistics (stays empty unless the commented-out block in the
# else branch above is re-enabled)
print(tongji_false)