# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: Dataset, collator, and DataLoader helpers for Chinese spelling correction (CSC).
"""
import os
import json

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast


class DataCollator:
    """Collates a batch of CSC samples and builds token-level detection labels."""

    def __init__(self, tokenizer: BertTokenizerFast):
        self.tokenizer = tokenizer

    def __call__(self, data):
        ori_texts, cor_texts, wrong_idss = zip(*data)
        encoded_texts = [self.tokenizer(t, return_offsets_mapping=True, add_special_tokens=False) for t in ori_texts]
        # +2 reserves positions for the [CLS] and [SEP] tokens added when the texts are later encoded with special tokens.
        max_len = max([len(t['input_ids']) for t in encoded_texts]) + 2
        det_labels = torch.zeros(len(ori_texts), max_len).long()

        for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
            off_mapping = encoded_text['offset_mapping']
            for idx in wrong_ids:
                # Find the token whose character span (b, e) covers the wrong character index.
                for j, (b, e) in enumerate(off_mapping):
                    if b <= idx < e:
                        # j + 1 accounts for the leading [CLS] token.
                        det_labels[i, j + 1] = 1
                        break

        return list(ori_texts), list(cor_texts), det_labels
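# Illustrative example (the sample text and index are not from the original file): for a
# batch item ("我是中过人", "我是中国人", [3]), a fast BERT tokenizer's offset mapping
# typically assigns one token per Chinese character, so the wrong character at index 3
# maps to token position 3 and det_labels gets a 1 at column 3 + 1 = 4 (shifted past
# [CLS]); every other column stays 0.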


class CscDataset(Dataset):
    """Reads CSC samples from a JSON file and yields (original_text, correct_text, wrong_ids) tuples."""

    def __init__(self, file_path):
        self.data = json.load(open(file_path, 'r', encoding='utf-8'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']
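# Expected file layout, inferred from the keys accessed above (the sample values are
# illustrative):
# [
#   {"original_text": "我是中过人", "correct_text": "我是中国人", "wrong_ids": [3]},
#   ...
# ]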


def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
                 batch_size=32, num_workers=4):
    """Builds DataLoaders for whichever of the train/valid/test paths exist; missing ones are returned as None."""
    train_loader = None
    if train_path and os.path.exists(train_path):
        train_loader = DataLoader(CscDataset(train_path),
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    valid_loader = None
    if valid_path and os.path.exists(valid_path):
        valid_loader = DataLoader(CscDataset(valid_path),
                                  batch_size=batch_size,
                                  num_workers=num_workers,
                                  collate_fn=collate_fn)
    test_loader = None
    if test_path and os.path.exists(test_path):
        test_loader = DataLoader(CscDataset(test_path),
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 collate_fn=collate_fn)
    return train_loader, valid_loader, test_loader
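# Usage sketch (the tokenizer name and file paths are illustrative assumptions, not
# values defined in this module):
#
#   tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
#   collator = DataCollator(tokenizer)
#   train_loader, valid_loader, test_loader = make_loaders(
#       collator, train_path='train.json', valid_path='dev.json', test_path='test.json')
#   for ori_texts, cor_texts, det_labels in train_loader:
#       pass  # ori_texts/cor_texts are lists of str; det_labels is a LongTensor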