1 changed file with 250 additions and 0 deletions
@@ -0,0 +1,250 @@
import json
import os
import re

# os.environ["WANDB_DISABLED"] = "true"

# Set the CUDA device
os.environ['CUDA_VISIBLE_DEVICES'] = '2'


import logging
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import evaluate
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    BertTokenizer,
    BertModel,
    BigBirdTokenizer,
    BigBirdForSequenceClassification
)
import torch
# Check whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


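# Label name -> id mapping: 标题 (heading), 正文 (body text), 无用类别 (useless category)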
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}

# lable_2_id_fenji = {
#     "标题+正文": 0,
#     "无用类别": 1
# }


tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_shout_3",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

model_name = "data_zong_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
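# Classifier applied below to the "middle" window (context on both sides of the target sentence)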
model_roberta = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)

model_name = "data_zong_no_start_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
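# Classifier applied below to the target sentence plus only the following context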
model_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)

model_name = "data_zong_no_end_shout_3"
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=len(lable_2_id_fenji),
    revision="main",
    trust_remote_code=False
)
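# Classifier applied below to only the preceding context plus the target sentence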
model_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    config=config,
    revision="main",
    trust_remote_code=False,
    ignore_mismatched_sizes=False,
).to(device)


# lable_2_id_fenji = {
#     "标题": 0,
#     "正文": 1,
#     "无用类别": 2
# }

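# Invert the label mapping: id -> label name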
id_2_lable = {v: k for k, v in lable_2_id_fenji.items()}

import pandas as pd

data_zong_1 = pd.read_csv("data/data_zong_dev_2.csv").values.tolist()
# data_zong_2 = pd.read_csv("data/data_zong_train.csv").values.tolist()
# data_zong = data_zong_1 + data_zong_2

data_zong = data_zong_1
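# Counters for the majority-vote evaluation; tongji_false would collect
# per-error statistics, but the code that fills it is commented out below.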
output_text = ""
output_text_false = ""
zong_len = 0
zhengque = 0

tongji_false = {}


data_zong_true = []
data_zong_false = []
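# For each dev example: the target sentence is delimited by <Start>/<End> markers.
# Build three context windows around it, classify each window with the matching
# model, then take a majority vote over the three predictions.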
for i in tqdm(range(len(data_zong))):
    zong_len += 1
    paper = data_zong[i][0]
    lable_true = data_zong[i][1]
    paper_new_start_1, paper_new_start_2 = paper.split("<Start>")
    paper_new_end_1, paper_new_end_2 = paper.split("<End>")

    # Window: 7 sentences of context before and after the target
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-7:])
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:7])
    paper_object = "<Start>" + paper_new_start_2.split("\n")[0]
    paper_zhong = "\n".join([paper_new_start_1_jiequ, paper_object, paper_new_end_2_jiequ])

    # Window: the 15 sentences before the target
    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-15:])
    paper_qian = "\n".join([paper_new_start_1_jiequ, paper_object])

    # Window: the 15 sentences after the target
    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:15])
    paper_hou = "\n".join([paper_object, paper_new_end_2_jiequ])

    # Prediction with the target sentence in the middle of the window
    sentence_list = [paper_zhong]
    # sentence_list = [data[1][0]]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        logits = model_roberta(**result_on_device)
    predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item()

    # Prediction from the preceding-context window
    sentence_list = [paper_qian]
    # sentence_list = [data[1][0]]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        logits = model_roberta_no_end(**result_on_device)
    predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()
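
    # Prediction from the following-context window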
    sentence_list = [paper_hou]
    # sentence_list = [data[1][0]]
    result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    result_on_device = {key: value.to(device) for key, value in result.items()}
    with torch.no_grad():
        logits = model_roberta_no_start(**result_on_device)
    predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()

    output_text += paper
    output_text += "\n"
    output_text += f"Predicted class (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
    output_text += f"Predicted class (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
    output_text += f"Predicted class (middle window): {id_2_lable[predicted_class_idx_zhong]}\n"
    output_text += f"True class: {id_2_lable[lable_true]}\n"
    output_text += "\n==============================================================================\n"

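    # Count how many of the three windows voted for each class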
    id_2_len = {}
    for pred in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
        if pred not in id_2_len:
            id_2_len[pred] = 1
        else:
            id_2_len[pred] += 1

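    # Majority vote: accept a class only when at least two of the three windows agree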
    queding = False
    predicted_class_idx = ""
    for idx, votes in id_2_len.items():
        if votes >= 2:
            queding = True
            predicted_class_idx = idx

    if queding and predicted_class_idx == lable_true:
        zhengque += 1
        data_zong_true.append([paper, lable_true])

    else:
        # jian = "-true-" + id_2_lable[lable_true] + "-false-" + id_2_lable[predicted_class_idx]
        # if jian not in tongji_false:
        #     tongji_false[jian] = 1
        # else:
        #     tongji_false[jian] += 1
        output_text_false += paper
        output_text_false += "\n"
        output_text_false += f"Predicted class (preceding 15 sentences): {id_2_lable[predicted_class_idx_qian]}\n"
        output_text_false += f"Predicted class (following 15 sentences): {id_2_lable[predicted_class_idx_hou]}\n"
        output_text_false += f"Predicted class (middle window): {id_2_lable[predicted_class_idx_zhong]}\n"
        output_text_false += f"True class: {id_2_lable[lable_true]}\n"
        output_text_false += "\n==============================================================================\n"
        data_zong_false.append([paper])

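# Example JSONL records for the commented-out export of misclassified examples below: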
'''
{"text": "Terrible customer service.", "label": ["negative"]}
{"text": "Really great transaction.", "label": ["positive"]}
{"text": "Great price.", "label": ["positive"]}
'''
# with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f:
#     for i in data_zong_false:
#         f.write(json.dumps({
#             "text": i[0],
#             "label": []
#         }, ensure_ascii=False))
#         f.write("\n")
#
# pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv("data/data_zong_true.csv", index=False, encoding="utf-8")

with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f:
    f.write(output_text)

with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f:
    f.write(output_text_false)


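# Overall accuracy of the majority-vote ensemble on the dev set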
print(zhengque / zong_len)
print(tongji_false)