From 17b4fc802e0556855fa54ff888d06156175dbb85 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com"
Date: Thu, 13 Nov 2025 17:56:27 +0800
Subject: [PATCH] Change the position encoding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 run_glue.py  | 56 ++++++++++++++++++++++++++++++++++++++++++++++----------
 run_train.sh | 14 +++++++-------
 2 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/run_glue.py b/run_glue.py
index 3165d11..ee55318 100644
--- a/run_glue.py
+++ b/run_glue.py
@@ -22,7 +22,7 @@ import re
 os.environ["WANDB_DISABLED"] = "true"
 
 # Set the CUDA device
-os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
 
 
 import logging
@@ -307,7 +307,6 @@ def main():
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_glue", model_args, data_args)
     print("model_args", model_args)
-    9/0
 
     # Setup logging
     logging.basicConfig(
@@ -455,7 +454,7 @@ def main():
         token=model_args.token,
         trust_remote_code=model_args.trust_remote_code,
     )
-    tokenizer = BertTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
@@ -463,7 +462,7 @@ def main():
         token=model_args.token,
         trust_remote_code=model_args.trust_remote_code,
     )
-    model = BertForSequenceClassification.from_pretrained(
+    model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path,
         from_tf=bool(".ckpt" in model_args.model_name_or_path),
         config=config,
@@ -530,20 +529,57 @@ def main():
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
     def preprocess_function(examples):
-        print(examples)
-        print("1")
-        # print("examples[sentence1_key]", examples[sentence1_key])
-        # print("len(examples[sentence1_key])", len(examples[sentence1_key]))
-        # print("padding", padding)
+        print(1)
         # Tokenize the texts
         args = (
             (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
         )
         result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
-        # result = tokenizer_ulit(tokenizer, examples[sentence1_key], padding, max_seq_length)
+
         # Map labels to IDs (not necessary for GLUE tasks)
         if label_to_id is not None and "label" in examples:
             result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
+
+        # Define the pairs of token sequences to search for
+        sequence_pairs = [
+            ([133, 10906, 135], [133, 9931, 135])  # first sequence and second sequence
+        ]
+
+        # Process each sample
+        for i in range(len(result['input_ids'])):
+            input_ids = result['input_ids'][i]
+            token_type_ids = result['token_type_ids'][i]
+
+            # Handle each sequence pair
+            for seq1, seq2 in sequence_pairs:
+                seq1_len = len(seq1)
+                seq2_len = len(seq2)
+
+                # Find all occurrences of the first sequence
+                seq1_positions = []
+                for j in range(len(input_ids) - seq1_len + 1):
+                    if input_ids[j:j + seq1_len] == seq1:
+                        seq1_positions.append(j)
+
+                # Find all occurrences of the second sequence
+                seq2_positions = []
+                for j in range(len(input_ids) - seq2_len + 1):
+                    if input_ids[j:j + seq2_len] == seq2:
+                        seq2_positions.append(j)
+
+                # Handle every matched pair of occurrences
+                for pos1 in seq1_positions:
+                    for pos2 in seq2_positions:
+                        if pos1 < pos2:  # the first sequence must come before the second
+                            # Set token_type_ids to 1 between the two sequences (markers included)
+                            start_idx = pos1
+                            end_idx = pos2 + seq2_len
+
+                            for k in range(start_idx, end_idx):
+                                token_type_ids[k] = 1
+
+            # Write back the updated token_type_ids
+            result['token_type_ids'][i] = token_type_ids
         return result
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
diff --git a/run_train.sh b/run_train.sh
index febd7bd..edebfc1 100644
--- a/run_train.sh
+++ b/run_train.sh
@@ -1,11 +1,11 @@
 python run_glue.py \
-  --model_name_or_path chinese_bert_wwm_ext_pytorch \
-  --train_file data/train_data_weipu.csv \
-  --validation_file data/dev_data_weipu.csv \
+  --model_name_or_path /home/majiahui/project/models-llm/longformer-chinese-base-4096 \
+  --train_file data/long_paper_train_3_1.csv \
+  --validation_file data/long_paper_dev_3_1.csv \
   --do_train \
   --do_eval \
-  --max_seq_length 512 \
-  --per_device_train_batch_size 32 \
+  --max_seq_length 4096 \
+  --per_device_train_batch_size 4 \
   --learning_rate 2e-5 \
-  --num_train_epochs 5 \
-  --output_dir aigc_check
\ No newline at end of file
+  --num_train_epochs 1 \
+  --output_dir long_paper_ceshi
\ No newline at end of file
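
Note: the block added to preprocess_function marks the span between two marker token sequences by flipping token_type_ids to 1. Below is a minimal, self-contained sketch of that logic so it can be sanity-checked outside the Trainer pipeline. It assumes, as the patch does, that [133, 10906, 135] and [133, 9931, 135] are the token-ID encodings of a begin/end marker pair present in the training text; the helper names find_subsequence and mark_span are illustrative only and are not part of run_glue.py.

# Sketch of the span-marking logic from the patched preprocess_function.
# Assumption: the two ID lists below encode a begin marker and an end marker in the input text.

def find_subsequence(ids, pattern):
    """Return every start index where `pattern` occurs as a contiguous slice of `ids`."""
    n = len(pattern)
    return [j for j in range(len(ids) - n + 1) if ids[j:j + n] == pattern]

def mark_span(input_ids, token_type_ids, start_seq, end_seq):
    """Set token_type_ids to 1 from each start marker through the end of any later end marker."""
    starts = find_subsequence(input_ids, start_seq)
    ends = find_subsequence(input_ids, end_seq)
    for pos1 in starts:
        for pos2 in ends:
            if pos1 < pos2:  # the start marker must precede the end marker
                for k in range(pos1, pos2 + len(end_seq)):
                    token_type_ids[k] = 1
    return token_type_ids

if __name__ == "__main__":
    # Toy example: the markers sit at positions 2-4 and 8-10.
    input_ids = [101, 7, 133, 10906, 135, 55, 66, 77, 133, 9931, 135, 88, 102]
    token_type_ids = [0] * len(input_ids)
    print(mark_span(input_ids, token_type_ids, [133, 10906, 135], [133, 9931, 135]))
    # -> [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]

On the toy input, indices 2 through 10 (both markers and everything between them) end up with segment id 1, which is what the patched preprocess_function writes back into result['token_type_ids'] for each example.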