"""Evaluate a three-model sequence-classification ensemble on the dev split.

Every row of ``data/data_zong_dev_2.csv`` holds a text (column 0) in which a
target line is delimited by start/end markers, plus a gold class id
(column 1). The target line is classified three times, each time with a
different window of surrounding lines:

* ``model_roberta`` (``data_zong_shout_3``): 7 lines before and 7 after,
* ``model_roberta_no_end`` (``data_zong_no_end_shout_3``): 15 lines before,
* ``model_roberta_no_start`` (``data_zong_no_start_shout_3``): 15 lines after.

A row counts as correct only when at least two of the three models agree and
the agreed class equals the gold id. Per-row predictions are written to
``data/data_zong_output_text_dev_roberta_zonghe.txt``. Rows not counted as
correct are also written to ``..._false.txt``. Accuracy and a
``true -> predicted`` error tally are printed at the end.
"""
import json
import logging
import os
import random
import re
import sys
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional

# os.environ["WANDB_DISABLED"] = "true"
# Select the CUDA device. This must happen before torch initialises CUDA, so
# it stays ahead of the third-party imports below.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import datasets
import evaluate
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    BertTokenizer,
    BertModel,
    BigBirdTokenizer,
    BigBirdForSequenceClassification,
)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Label name -> class id; must match how the three checkpoints were trained.
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}
# Alternative two-class scheme (requires checkpoints trained for it):
# lable_2_id_fenji = {
#     "标题+正文": 0,
#     "无用类别": 1
# }

# Inverse mapping (class id -> label name). If two names share an id, the
# first one wins.
id_2_lable = {}
for _name, _idx in lable_2_id_fenji.items():
    id_2_lable.setdefault(_idx, _name)

# Markers that delimit the target line inside each text.
# NOTE(review): in the file as received both literals were empty strings, and
# ``str.split("")`` raises ``ValueError: empty separator`` on the first row.
# The real markers were most likely tag-like tokens lost in transit.
# "<start>"/"<end>" is a best guess; confirm against data/data_zong_dev_2.csv
# and the training data before trusting any results.
TARGET_START_MARKER = "<start>"
TARGET_END_MARKER = "<end>"

# Context window sizes (number of lines kept on each side of the target).
NEAR_WINDOW = 7
FAR_WINDOW = 15

MAX_LENGTH = 512

SEPARATOR = "\n==============================================================================\n"

# One tokenizer (from the centred-context checkpoint) is shared by all three
# models. NOTE(review): this assumes the three checkpoints share a vocabulary.
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_shout_3",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)


def load_classifier(model_name):
    """Load a sequence-classification checkpoint onto ``device``.

    Args:
        model_name: local directory or hub id of the checkpoint.

    Returns:
        The model in eval mode, with ``len(lable_2_id_fenji)`` output labels.
    """
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=len(lable_2_id_fenji),
        revision="main",
        trust_remote_code=False,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)
    # Inference only: make sure dropout is disabled.
    model.eval()
    return model


model_roberta = load_classifier("data_zong_shout_3")
model_roberta_no_start = load_classifier("data_zong_no_start_shout_3")
model_roberta_no_end = load_classifier("data_zong_no_end_shout_3")


def _split_once(text, marker):
    """Split *text* around the single occurrence of *marker*.

    Raises:
        ValueError: if *marker* is empty or does not occur exactly once.
    """
    if not marker:
        raise ValueError("target marker must be a non-empty string")
    parts = text.split(marker)
    if len(parts) != 2:
        raise ValueError(
            f"expected exactly one {marker!r} in text, found {len(parts) - 1}"
        )
    return parts[0], parts[1]


def build_contexts(paper, near=NEAR_WINDOW, far=FAR_WINDOW):
    """Build the three context windows around the marked target line.

    Args:
        paper: text containing one start marker and one end marker.
        near: lines kept on each side for the centred window.
        far: lines kept for the one-sided windows.

    Returns:
        ``(centred, preceding, following)`` strings. The target line (with
        its start marker re-attached) is in the middle, at the end, and at
        the start of each one, respectively.
    """
    before_target, from_target = _split_once(paper, TARGET_START_MARKER)
    _, after_target = _split_once(paper, TARGET_END_MARKER)
    before_lines = before_target.strip("\n").split("\n")
    after_lines = after_target.strip("\n").split("\n")
    # The split removed the start marker; put it back so the models see the
    # same marked target they were (presumably) trained on.
    target_line = TARGET_START_MARKER + from_target.split("\n")[0]

    centred = "\n".join([
        "\n".join(before_lines[-near:]),
        target_line,
        "\n".join(after_lines[:near]),
    ])
    preceding = "\n".join(["\n".join(before_lines[-far:]), target_line])
    following = "\n".join([target_line, "\n".join(after_lines[:far])])
    return centred, preceding, following


def predict_class(model, text):
    """Return the arg-max class id that *model* assigns to *text*."""
    encoded = tokenizer(
        [text],
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    # Evaluation only: skip building an autograd graph for every forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.argmax(logits, dim=1).item()


def majority_vote(predictions):
    """Return the class predicted by at least two voters, or ``None`` if none."""
    label, votes = Counter(predictions).most_common(1)[0]
    return label if votes >= 2 else None


def format_record(paper, idx_qian, idx_hou, idx_zhong, lable_true):
    """Render one row and its three predictions for the report files."""
    return "".join([
        paper,
        "\n",
        f"前15句预测的类别索引: {id_2_lable[idx_qian]}\n",
        f"后15句预测的类别索引: {id_2_lable[idx_hou]}\n",
        f"中间句子预测的类别索引: {id_2_lable[idx_zhong]}\n",
        f"真实的类别索引: {id_2_lable[lable_true]}\n",
        SEPARATOR,
    ])


data_zong_1 = pd.read_csv("data/data_zong_dev_2.csv").values.tolist()
# data_zong_2 = pd.read_csv("data/data_zong_train.csv").values.tolist()
# data_zong = data_zong_1 + data_zong_2
data_zong = data_zong_1

output_parts = []
output_parts_false = []
zong_len = 0
zhengque = 0
tongji_false = {}
data_zong_true = []
data_zong_false = []

for row in tqdm(data_zong):
    zong_len += 1
    paper = row[0]
    # Gold class id; it is looked up in id_2_lable below, so it must be one
    # of the ids in lable_2_id_fenji.
    lable_true = row[1]

    paper_zhong, paper_qian, paper_hou = build_contexts(paper)

    # Each window goes to the model trained on that kind of context.
    predicted_class_idx_zhong = predict_class(model_roberta, paper_zhong)
    predicted_class_idx_qian = predict_class(model_roberta_no_end, paper_qian)
    predicted_class_idx_hou = predict_class(model_roberta_no_start, paper_hou)

    record = format_record(
        paper,
        predicted_class_idx_qian,
        predicted_class_idx_hou,
        predicted_class_idx_zhong,
        lable_true,
    )
    output_parts.append(record)

    predicted_class_idx = majority_vote(
        [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]
    )

    if predicted_class_idx is not None and predicted_class_idx == lable_true:
        zhengque += 1
        data_zong_true.append([paper, lable_true])
    else:
        # Tally errors as "-true-<gold>-false-<voted>"; rows without a
        # majority are recorded under "no_majority".
        predicted_name = (
            id_2_lable[predicted_class_idx]
            if predicted_class_idx is not None
            else "no_majority"
        )
        jian = "-true-" + id_2_lable[lable_true] + "-false-" + predicted_name
        tongji_false[jian] = tongji_false.get(jian, 0) + 1

        output_parts_false.append(record)
        data_zong_false.append([paper])

output_text = "".join(output_parts)
output_text_false = "".join(output_parts_false)

# Set to True to also export the rows not counted as correct as JSONL (one
# {"text": ..., "label": []} object per line; labelled example:
# {"text": "Great price.", "label": ["positive"]}) and the correct rows as CSV.
EXPORT_SPLITS = False
if EXPORT_SPLITS:
    with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f:
        for item in data_zong_false:
            f.write(json.dumps({"text": item[0], "label": []}, ensure_ascii=False))
            f.write("\n")
    pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv(
        "data/data_zong_true.csv", index=False, encoding="utf-8"
    )

with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f:
    f.write(output_text)

with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f:
    f.write(output_text_false)

# Accuracy is undefined for an empty dataset; report NaN instead of crashing.
print(zhengque / zong_len if zong_len else float("nan"))
print(tongji_false)