1 changed files with 250 additions and 0 deletions
@ -0,0 +1,250 @@ |
|||||
|
import json |
||||
|
import os |
||||
|
import re |
||||
|
|
||||
|
# os.environ["WANDB_DISABLED"] = "true" |
||||
|
|
||||
|
# 设置CUDA设备 |
||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = '2' |
||||
|
|
||||
|
|
||||
|
import logging |
||||
|
import os |
||||
|
import random |
||||
|
import sys |
||||
|
from dataclasses import dataclass, field |
||||
|
from typing import Optional |
||||
|
|
||||
|
import datasets |
||||
|
import evaluate |
||||
|
import numpy as np |
||||
|
from datasets import load_dataset |
||||
|
from tqdm import tqdm |
||||
|
import transformers |
||||
|
from transformers import ( |
||||
|
AutoConfig, |
||||
|
AutoModelForSequenceClassification, |
||||
|
AutoTokenizer, |
||||
|
DataCollatorWithPadding, |
||||
|
EvalPrediction, |
||||
|
HfArgumentParser, |
||||
|
PretrainedConfig, |
||||
|
Trainer, |
||||
|
TrainingArguments, |
||||
|
default_data_collator, |
||||
|
set_seed, |
||||
|
BertTokenizer, |
||||
|
BertModel, |
||||
|
BigBirdTokenizer, |
||||
|
BigBirdForSequenceClassification |
||||
|
) |
||||
|
import torch |
||||
|
# 检查GPU是否可用 |
||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
||||
|
print(f"使用设备: {device}") |
||||
|
|
||||
|
|
||||
|
lable_2_id_fenji = { |
||||
|
"标题": 0, |
||||
|
"正文": 1, |
||||
|
"无用类别": 2 |
||||
|
} |
||||
|
|
||||
|
# lable_2_id_fenji = { |
||||
|
# "标题+正文": 0, |
||||
|
# "无用类别": 1 |
||||
|
# } |
||||
|
|
||||
|
|
||||
|
tokenizer = AutoTokenizer.from_pretrained( |
||||
|
"data_zong_shout_3", |
||||
|
use_fast=True, |
||||
|
revision="main", |
||||
|
trust_remote_code=False, |
||||
|
) |
||||
|
|
||||
|
model_name = "data_zong_shout_3" |
||||
|
config = AutoConfig.from_pretrained( |
||||
|
model_name, |
||||
|
num_labels=len(lable_2_id_fenji), |
||||
|
revision="main", |
||||
|
trust_remote_code=False |
||||
|
) |
||||
|
model_roberta = AutoModelForSequenceClassification.from_pretrained( |
||||
|
model_name, |
||||
|
config=config, |
||||
|
revision="main", |
||||
|
trust_remote_code=False, |
||||
|
ignore_mismatched_sizes=False, |
||||
|
).to(device) |
||||
|
|
||||
|
model_name = "data_zong_no_start_shout_3" |
||||
|
config = AutoConfig.from_pretrained( |
||||
|
model_name, |
||||
|
num_labels=len(lable_2_id_fenji), |
||||
|
revision="main", |
||||
|
trust_remote_code=False |
||||
|
) |
||||
|
model_roberta_no_start = AutoModelForSequenceClassification.from_pretrained( |
||||
|
model_name, |
||||
|
config=config, |
||||
|
revision="main", |
||||
|
trust_remote_code=False, |
||||
|
ignore_mismatched_sizes=False, |
||||
|
).to(device) |
||||
|
|
||||
|
model_name = "data_zong_no_end_shout_3" |
||||
|
config = AutoConfig.from_pretrained( |
||||
|
model_name, |
||||
|
num_labels=len(lable_2_id_fenji), |
||||
|
revision="main", |
||||
|
trust_remote_code=False |
||||
|
) |
||||
|
model_roberta_no_end = AutoModelForSequenceClassification.from_pretrained( |
||||
|
model_name, |
||||
|
config=config, |
||||
|
revision="main", |
||||
|
trust_remote_code=False, |
||||
|
ignore_mismatched_sizes=False, |
||||
|
).to(device) |
||||
|
|
||||
|
|
||||
|
# lable_2_id_fenji = { |
||||
|
# "标题": 0, |
||||
|
# "正文": 1, |
||||
|
# "无用类别": 2 |
||||
|
# } |
||||
|
|
||||
|
|
||||
|
id_2_lable = {} |
||||
|
for i in lable_2_id_fenji: |
||||
|
if lable_2_id_fenji[i] not in id_2_lable: |
||||
|
id_2_lable[lable_2_id_fenji[i]] = i |
||||
|
|
||||
|
import pandas as pd |
||||
|
import torch |
||||
|
|
||||
|
data_zong_1 = pd.read_csv("data/data_zong_dev_2.csv").values.tolist() |
||||
|
# data_zong_2 = pd.read_csv("data/data_zong_train.csv").values.tolist() |
||||
|
# data_zong = data_zong_1 + data_zong_2 |
||||
|
|
||||
|
data_zong = data_zong_1 |
||||
|
output_text = "" |
||||
|
output_text_false = "" |
||||
|
zong_len = 0 |
||||
|
zhengque = 0 |
||||
|
|
||||
|
tongji_false = {} |
||||
|
|
||||
|
|
||||
|
data_zong_true = [] |
||||
|
data_zong_false = [] |
||||
|
for i in tqdm(range(len(data_zong))): |
||||
|
zong_len += 1 |
||||
|
paper = data_zong[i][0] |
||||
|
lable_true = data_zong[i][1] |
||||
|
paper_new_start_1, paper_new_start_2 = paper.split("<Start>") |
||||
|
paper_new_end_1, paper_new_end_2 = paper.split("<End>") |
||||
|
|
||||
|
# 视野前后7句 |
||||
|
paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-7:]) |
||||
|
paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:7]) |
||||
|
paper_object = "<Start>" + paper_new_start_2.split("\n")[0] |
||||
|
paper_zhong = "\n".join([paper_new_start_1_jiequ, paper_object, paper_new_end_2_jiequ]) |
||||
|
|
||||
|
# 视野前15句 |
||||
|
paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-15:]) |
||||
|
paper_qian = "\n".join([paper_new_start_1_jiequ, paper_object]) |
||||
|
|
||||
|
# 视野后15句 |
||||
|
paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:15]) |
||||
|
paper_hou = "\n".join([paper_object, paper_new_end_2_jiequ]) |
||||
|
|
||||
|
# 目标句子在中间预测结果 |
||||
|
sentence_list = [paper_zhong] |
||||
|
# sentence_list = [data[1][0]] |
||||
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
||||
|
result_on_device = {key: value.to(device) for key, value in result.items()} |
||||
|
logits = model_roberta(**result_on_device) |
||||
|
predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item() |
||||
|
|
||||
|
# |
||||
|
sentence_list = [paper_qian] |
||||
|
# sentence_list = [data[1][0]] |
||||
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
||||
|
result_on_device = {key: value.to(device) for key, value in result.items()} |
||||
|
logits = model_roberta_no_end(**result_on_device) |
||||
|
predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item() |
||||
|
sentence_list = [paper_hou] |
||||
|
# sentence_list = [data[1][0]] |
||||
|
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") |
||||
|
result_on_device = {key: value.to(device) for key, value in result.items()} |
||||
|
logits = model_roberta_no_start(**result_on_device) |
||||
|
predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item() |
||||
|
|
||||
|
output_text += paper |
||||
|
output_text += "\n" |
||||
|
output_text += f"前15句预测的类别索引: {id_2_lable[predicted_class_idx_qian]}\n" |
||||
|
output_text += f"后15句预测的类别索引: {id_2_lable[predicted_class_idx_hou]}\n" |
||||
|
output_text += f"中间句子预测的类别索引: {id_2_lable[predicted_class_idx_zhong]}\n" |
||||
|
output_text += f"真实的类别索引: {id_2_lable[lable_true]}\n" |
||||
|
output_text += "\n==============================================================================\n" |
||||
|
|
||||
|
id_2_len = {} |
||||
|
for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]: |
||||
|
if i not in id_2_len: |
||||
|
id_2_len[i] = 1 |
||||
|
else: |
||||
|
id_2_len[i] += 1 |
||||
|
|
||||
|
queding = False |
||||
|
predicted_class_idx = "" |
||||
|
for i in id_2_len: |
||||
|
if id_2_len[i] >= 2: |
||||
|
queding = True |
||||
|
predicted_class_idx = i |
||||
|
|
||||
|
if queding == True and predicted_class_idx == lable_true: |
||||
|
zhengque += 1 |
||||
|
data_zong_true.append([paper, lable_true]) |
||||
|
|
||||
|
else: |
||||
|
# jian = "-true-" + id_2_lable[lable_true] + "-false-" + id_2_lable[predicted_class_idx] |
||||
|
# if jian not in tongji_false: |
||||
|
# tongji_false[jian] = 1 |
||||
|
# else: |
||||
|
# tongji_false[jian] += 1 |
||||
|
output_text_false += paper |
||||
|
output_text_false += "\n" |
||||
|
output_text_false += f"前15句预测的类别索引: {id_2_lable[predicted_class_idx_qian]}\n" |
||||
|
output_text_false += f"后15句预测的类别索引: {id_2_lable[predicted_class_idx_hou]}\n" |
||||
|
output_text_false += f"中间句子预测的类别索引: {id_2_lable[predicted_class_idx_zhong]}\n" |
||||
|
output_text_false += f"真实的类别索引: {id_2_lable[lable_true]}\n" |
||||
|
output_text_false += "\n==============================================================================\n" |
||||
|
data_zong_false.append([paper]) |
||||
|
|
||||
|
''' |
||||
|
{"text": "Terrible customer service.", "label": ["negative"]} |
||||
|
{"text": "Really great transaction.", "label": ["positive"]} |
||||
|
{"text": "Great price.", "label": ["positive"]} |
||||
|
''' |
||||
|
# with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f: |
||||
|
# for i in data_zong_false: |
||||
|
# f.write(json.dumps({ |
||||
|
# "text": i[0], |
||||
|
# "label": [] |
||||
|
# }, ensure_ascii=False)) |
||||
|
# f.write("\n") |
||||
|
# |
||||
|
# pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv("data/data_zong_true.csv", index=False, encoding="utf-8") |
||||
|
|
||||
|
with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f: |
||||
|
f.write(output_text) |
||||
|
|
||||
|
with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f: |
||||
|
f.write(output_text_false) |
||||
|
|
||||
|
|
||||
|
|
||||
|
print(zhengque/zong_len) |
||||
|
print(tongji_false) |
||||
Loading…
Reference in new issue