From 735a7ec655416e92beefd608c44bb2befd759c7a Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com"
Date: Wed, 3 Dec 2025 16:20:59 +0800
Subject: [PATCH] Test the first scheme for three-way paper sentence
 classification
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 predict_long_text_roberta.py | 250 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 250 insertions(+)
 create mode 100644 predict_long_text_roberta.py

diff --git a/predict_long_text_roberta.py b/predict_long_text_roberta.py
new file mode 100644
index 0000000..fe3705e
--- /dev/null
+++ b/predict_long_text_roberta.py
@@ -0,0 +1,250 @@
+import json
+import os
+
+# os.environ["WANDB_DISABLED"] = "true"
+
+# Select the CUDA device before torch initializes.
+os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+
+import pandas as pd
+import torch
+from tqdm import tqdm
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+
+# Use the GPU when one is available.
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
+
+
+lable_2_id_fenji = {
+    "title": 0,
+    "body": 1,
+    "irrelevant": 2
+}
+
+# Alternative two-way scheme:
+# lable_2_id_fenji = {
+#     "title+body": 0,
+#     "irrelevant": 1
+# }
+
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "data_zong_shout_3",
+    use_fast=True,
+    revision="main",
+    trust_remote_code=False,
+)
+
+# Model trained with context on both sides of the target sentence.
+model_name = "data_zong_shout_3"
+config = AutoConfig.from_pretrained(
+    model_name,
+    num_labels=len(lable_2_id_fenji),
+    revision="main",
+    trust_remote_code=False
+)
+model_roberta = AutoModelForSequenceClassification.from_pretrained(
+    model_name,
+    config=config,
+    revision="main",
+    trust_remote_code=False,
+    ignore_mismatched_sizes=False,
+).to(device)
+model_roberta.eval()
+
+# Model trained without the preceding context (target + following sentences).
+model_name = "data_zong_no_start_shout_3"
+config = AutoConfig.from_pretrained(
+    model_name,
+    num_labels=len(lable_2_id_fenji),
+    revision="main",
+    trust_remote_code=False
+)
+model_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
+    model_name,
+    config=config,
+    revision="main",
+    trust_remote_code=False,
+    ignore_mismatched_sizes=False,
+).to(device)
+model_roberta_no_start.eval()
+
+# Model trained without the following context (preceding sentences + target).
+model_name = "data_zong_no_end_shout_3"
+config = AutoConfig.from_pretrained(
+    model_name,
+    num_labels=len(lable_2_id_fenji),
+    revision="main",
+    trust_remote_code=False
+)
+model_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
+    model_name,
+    config=config,
+    revision="main",
+    trust_remote_code=False,
+    ignore_mismatched_sizes=False,
+).to(device)
+model_roberta_no_end.eval()
+
+
+# Reverse mapping: class id -> label name.
+id_2_lable = {v: k for k, v in lable_2_id_fenji.items()}
+
+# NOTE: the marker strings that delimit the target sentence were lost when
+# this patch was extracted; START_MARK and END_MARK below are placeholders
+# for the original tokens.
+START_MARK = "<start>"  # placeholder, original token unknown
+END_MARK = "<end>"      # placeholder, original token unknown
+
+data_zong = pd.read_csv("data/data_zong_dev_2.csv").values.tolist()
+# data_zong += pd.read_csv("data/data_zong_train.csv").values.tolist()
+
+output_text = ""
+output_text_false = ""
+zong_len = 0
+zhengque = 0
+
+tongji_false = {}
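+
+# Assumed sample layout (editorial note; the exact format did not survive
+# extraction): each CSV row is (paper_excerpt, label_id), where paper_excerpt
+# is a block of newline-separated sentences with START_MARK/END_MARK wrapped
+# around the single target sentence to classify, e.g.
+#   "...previous sentences...\n<start>target sentence<end>\n...next sentences..."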
+
+
+def predict(model, text):
+    """Tokenize `text` and return the argmax class id predicted by `model`."""
+    enc = tokenizer([text], padding="max_length", max_length=512,
+                    truncation=True, return_tensors="pt")
+    enc = {key: value.to(device) for key, value in enc.items()}
+    with torch.no_grad():
+        logits = model(**enc).logits
+    return torch.argmax(logits, dim=1).item()
+
+
+data_zong_true = []
+data_zong_false = []
+for i in tqdm(range(len(data_zong))):
+    zong_len += 1
+    paper = data_zong[i][0]
+    lable_true = data_zong[i][1]
+    paper_new_start_1, paper_new_start_2 = paper.split(START_MARK)
+    paper_new_end_1, paper_new_end_2 = paper.split(END_MARK)
+
+    # The target sentence is the first line after START_MARK; the marker is
+    # kept so the input matches what the models saw during training.
+    paper_object = START_MARK + paper_new_start_2.split("\n")[0]
+
+    # View 1: 7 sentences before and after the target.
+    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-7:])
+    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:7])
+    paper_zhong = "\n".join([paper_new_start_1_jiequ, paper_object, paper_new_end_2_jiequ])
+
+    # View 2: 15 sentences before the target.
+    paper_new_start_1_jiequ = "\n".join(paper_new_start_1.strip("\n").split("\n")[-15:])
+    paper_qian = "\n".join([paper_new_start_1_jiequ, paper_object])
+
+    # View 3: 15 sentences after the target.
+    paper_new_end_2_jiequ = "\n".join(paper_new_end_2.strip("\n").split("\n")[:15])
+    paper_hou = "\n".join([paper_object, paper_new_end_2_jiequ])
+
+    # One prediction per view, each from the matching model.
+    predicted_class_idx_zhong = predict(model_roberta, paper_zhong)
+    predicted_class_idx_qian = predict(model_roberta_no_end, paper_qian)
+    predicted_class_idx_hou = predict(model_roberta_no_start, paper_hou)
+
+    output_text += paper
+    output_text += "\n"
+    output_text += f"preceding-15-sentence prediction: {id_2_lable[predicted_class_idx_qian]}\n"
+    output_text += f"following-15-sentence prediction: {id_2_lable[predicted_class_idx_hou]}\n"
+    output_text += f"centered-window prediction: {id_2_lable[predicted_class_idx_zhong]}\n"
+    output_text += f"ground-truth label: {id_2_lable[lable_true]}\n"
+    output_text += "\n==============================================================================\n"
+
+    # Majority vote over the three views: accept a class that at least two
+    # views agree on. When all three disagree, predicted_class_idx stays None
+    # and the sample is counted as a miss.
+    id_2_len = {}
+    for pred in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
+        id_2_len[pred] = id_2_len.get(pred, 0) + 1
+
+    queding = False
+    predicted_class_idx = None
+    for pred, count in id_2_len.items():
+        if count >= 2:
+            queding = True
+            predicted_class_idx = pred
+
+    if queding and predicted_class_idx == lable_true:
+        zhengque += 1
+        data_zong_true.append([paper, lable_true])
+    else:
+        # Optional tally of error types:
+        # jian = "-true-" + id_2_lable[lable_true] + "-false-" + id_2_lable[predicted_class_idx]
+        # tongji_false[jian] = tongji_false.get(jian, 0) + 1
+        output_text_false += paper
+        output_text_false += "\n"
+        output_text_false += f"preceding-15-sentence prediction: {id_2_lable[predicted_class_idx_qian]}\n"
+        output_text_false += f"following-15-sentence prediction: {id_2_lable[predicted_class_idx_hou]}\n"
+        output_text_false += f"centered-window prediction: {id_2_lable[predicted_class_idx_zhong]}\n"
+        output_text_false += f"ground-truth label: {id_2_lable[lable_true]}\n"
+        output_text_false += "\n==============================================================================\n"
+        data_zong_false.append([paper])
+
+'''
+Example jsonl layout for the commented-out export below:
+{"text": "Terrible customer service.", "label": ["negative"]}
+{"text": "Really great transaction.", "label": ["positive"]}
+{"text": "Great price.", "label": ["positive"]}
+'''
+# with open("data/data_zong_false.jsonl", "w", encoding="utf-8") as f:
+#     for i in data_zong_false:
+#         f.write(json.dumps({
+#             "text": i[0],
+#             "label": []
+#         }, ensure_ascii=False))
+#         f.write("\n")
+#
+# pd.DataFrame(data_zong_true, columns=["sentence", "label"]).to_csv(
+#     "data/data_zong_true.csv", index=False, encoding="utf-8")
+
+with open("data/data_zong_output_text_dev_roberta_zonghe.txt", "w", encoding="utf-8") as f:
+    f.write(output_text)
+
+with open("data/data_zong_output_text_dev_roberta_zonghe_false.txt", "w", encoding="utf-8") as f:
+    f.write(output_text_false)
+
+print(f"accuracy: {zhengque / zong_len}")
+print(tongji_false)
\ No newline at end of file