纠错任务 (error correction task)

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: Dataset and collate utilities for the Chinese spelling correction (CSC) task.
"""
import os
import json

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast


class DataCollator:
    def __init__(self, tokenizer: BertTokenizerFast):
        self.tokenizer = tokenizer

    def __call__(self, data):
        ori_texts, cor_texts, wrong_idss = zip(*data)
        encoded_texts = [self.tokenizer(t, return_offsets_mapping=True, add_special_tokens=False) for t in ori_texts]
        # +2 reserves room for the [CLS] and [SEP] tokens added at model time
        max_len = max(len(t['input_ids']) for t in encoded_texts) + 2
        det_labels = torch.zeros(len(ori_texts), max_len).long()
        for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
            off_mapping = encoded_text['offset_mapping']
            for idx in wrong_ids:
                # Locate the token whose character span covers the wrong character index
                for j, (b, e) in enumerate(off_mapping):
                    if b <= idx < e:
                        # j + 1 accounts for the leading [CLS] token
                        det_labels[i, j + 1] = 1
                        break
        return list(ori_texts), list(cor_texts), det_labels
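
# Example (a hedged sketch, not from the original source): for a batch item
# ("我喜欢吃平果", "我喜欢吃苹果", [4]), character index 4 ("平") falls inside the
# offset span of exactly one token, so det_labels[i, j + 1] is set to 1 at that
# token's position; the +1 skips the [CLS] slot prepended when the model encodes
# the text with special tokens.
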
class CscDataset(Dataset):
    def __init__(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']
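
# Expected JSON layout (a sketch inferred from __getitem__; the sample values
# below are illustrative assumptions, and extra keys may be present):
# [
#   {
#     "original_text": "我喜欢吃平果",
#     "correct_text": "我喜欢吃苹果",
#     "wrong_ids": [4]
#   }
# ]
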
def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
                 batch_size=32, num_workers=4):
    train_loader = None
    if train_path and os.path.exists(train_path):
        train_loader = DataLoader(CscDataset(train_path),
                                  batch_size=batch_size,
                                  # NOTE: shuffling is disabled here; enable it
                                  # if the training data is not pre-shuffled
                                  shuffle=False,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    valid_loader = None
    if valid_path and os.path.exists(valid_path):
        valid_loader = DataLoader(CscDataset(valid_path),
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    test_loader = None
    if test_path and os.path.exists(test_path):
        test_loader = DataLoader(CscDataset(test_path),
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return train_loader, valid_loader, test_loader
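

# Minimal usage sketch (an addition, not part of the original module; the
# 'bert-base-chinese' checkpoint and 'train.json' path are placeholder
# assumptions):
if __name__ == '__main__':
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
    collator = DataCollator(tokenizer)
    # num_workers=0 keeps the sketch single-process and easy to debug
    train_loader, valid_loader, test_loader = make_loaders(
        collator, train_path='train.json', batch_size=32, num_workers=0)
    if train_loader is not None:
        ori_texts, cor_texts, det_labels = next(iter(train_loader))
        print(len(ori_texts), det_labels.shape)  # batch size, (batch, max_len)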