纠错任务
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

192 lines
7.1 KiB

2 years ago
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import operator
from abc import ABC
from loguru import logger
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from pycorrector.macbert import lr_scheduler
from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf
class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss.
    copy from https://github.com/lonePatient/TorchBlocks

    Args:
        num_labels: number of classes; width of the one-hot target built in
            the softmax branch.
        activation_type: 'softmax' (single-label, class-index targets) or
            'sigmoid' (element-wise targets, e.g. multi-hot).
        gamma: focusing exponent applied to (1 - p) / p.
        alpha: weight of the positive term; the sigmoid negative term gets
            (1 - alpha).
        epsilon: added inside log() to avoid log(0).
    """

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Compute the mean focal loss.

        Args:
            input: raw (pre-activation) model scores. For 'softmax' the last
                dim is the class dim, e.g. [batch_size, num_labels].
            target: for 'softmax', class indices (flattened to one per row of
                ``input``); for 'sigmoid', a tensor of 0/1 labels matching
                ``input`` element-wise.
        Returns:
            Scalar tensor: the mean of the per-sample (softmax) or
            per-element (sigmoid) focal loss.
        Raises:
            ValueError: if ``activation_type`` is neither 'softmax' nor 'sigmoid'.
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            probs = torch.softmax(input, dim=-1)
            # Only the true-class column survives the one-hot mask.
            loss = -self.alpha * one_hot_key * torch.pow((1 - probs), self.gamma) * (probs + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            probs = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            # Positive term (target == 1) plus negative term (target == 0).
            loss = -self.alpha * multi_hot_key * torch.pow((1 - probs), self.gamma) * (probs + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(probs, self.gamma) * (1 - probs + self.epsilon).log()
        else:
            # Previously fell through and failed with UnboundLocalError on `loss`.
            raise ValueError(
                f"Unsupported activation_type {self.activation_type!r}; expected 'softmax' or 'sigmoid'"
            )
        return loss.mean()
def make_optimizer(cfg, model):
    """Create the torch optimizer named by ``cfg.SOLVER.OPTIMIZER_NAME``.

    Every trainable parameter (``requires_grad``) becomes its own param
    group. Parameters whose name contains "bias" use
    ``BASE_LR * BIAS_LR_FACTOR`` and ``WEIGHT_DECAY_BIAS``. All others use
    ``BASE_LR`` and ``WEIGHT_DECAY``. ``momentum=MOMENTUM`` is passed only
    when the optimizer is 'SGD'.

    Args:
        cfg: config object exposing the ``SOLVER.*`` fields read below.
        model: a module providing ``named_parameters()``.
    Returns:
        The constructed ``torch.optim`` optimizer.
    """
    solver = cfg.SOLVER
    param_groups = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights are not handed to the optimizer
        if "bias" in name:
            group_lr = solver.BASE_LR * solver.BIAS_LR_FACTOR
            group_wd = solver.WEIGHT_DECAY_BIAS
        else:
            group_lr = solver.BASE_LR
            group_wd = solver.WEIGHT_DECAY
        param_groups.append({"params": [param], "lr": group_lr, "weight_decay": group_wd})

    optimizer_cls = getattr(torch.optim, solver.OPTIMIZER_NAME)
    extra_kwargs = {"momentum": solver.MOMENTUM} if solver.OPTIMIZER_NAME == 'SGD' else {}
    return optimizer_cls(param_groups, **extra_kwargs)
def build_lr_scheduler(cfg, optimizer):
    """Instantiate the scheduler class named by ``cfg.SOLVER.SCHED``.

    The class is looked up on the project ``lr_scheduler`` module and
    receives the full set of warmup / multi-step / cosine options below as
    keyword arguments.

    Returns:
        A dict ``{'scheduler': <scheduler>, 'interval': cfg.SOLVER.INTERVAL}``
        (the scheduler-config form used in ``configure_optimizers``).
    """
    solver = cfg.SOLVER
    scheduler_kwargs = dict(
        optimizer=optimizer,
        # warmup options
        warmup_factor=solver.WARMUP_FACTOR,
        warmup_epochs=solver.WARMUP_EPOCHS,
        warmup_method=solver.WARMUP_METHOD,
        # multi-step lr scheduler options
        milestones=solver.STEPS,
        gamma=solver.GAMMA,
        # cosine annealing lr scheduler options
        max_iters=solver.MAX_ITER,
        delay_iters=solver.DELAY_ITERS,
        eta_min_lr=solver.ETA_MIN_LR,
    )
    scheduler_cls = getattr(lr_scheduler, solver.SCHED)
    return {'scheduler': scheduler_cls(**scheduler_kwargs), 'interval': solver.INTERVAL}
class BaseTrainingEngine(pl.LightningModule):
    """LightningModule base that stores the experiment config.

    Optimizer and LR scheduler are built from ``self.cfg`` via
    ``make_optimizer`` / ``build_lr_scheduler``.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after Module init so attribute registration works.
        self.cfg = cfg

    def configure_optimizers(self):
        """Return ``([optimizer], [scheduler_config])`` built from ``self.cfg``."""
        opt = make_optimizer(self.cfg, self)
        return [opt], [build_lr_scheduler(self.cfg, opt)]

    def on_validation_epoch_start(self) -> None:
        logger.info('Valid.')

    def on_test_epoch_start(self) -> None:
        logger.info('Testing...')
class CscTrainingModel(BaseTrainingEngine, ABC):
    """
    Base model for CSC (Chinese Spelling Correction); defines the training,
    validation/test and prediction steps.

    Subclasses are expected to provide ``self.tokenizer`` and a ``forward``
    method. Judging by the usage below, ``forward(ori_text, cor_text,
    det_labels)`` returns a tuple whose [0] and [1] are losses, whose [2] is
    thresholded at 0.5 (presumably detection probabilities) and whose [3] is
    arg-maxed over the last dim (presumably correction logits). TODO confirm
    against the concrete subclass.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        # loss weight: outputs[1] is weighted by w, outputs[0] by (1 - w)
        self.w = cfg.MODEL.HYPER_PARAMS[0]

    def _weighted_loss(self, outputs):
        """Combine the two loss terms returned by ``forward`` using ``self.w``."""
        return self.w * outputs[1] + (1 - self.w) * outputs[0]

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self._weighted_loss(outputs)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(ori_text))
        return loss

    def validation_step(self, batch, batch_idx):
        """Return ``(loss, det_acc_labels, cor_acc_labels, results)`` for one batch.

        ``results`` holds per-sentence ``(src_ids, tgt_ids, predicted_ids)``
        triples consumed by ``validation_epoch_end``.
        """
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self._weighted_loss(outputs)
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax(outputs[3], dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        # Public Lightning `device` property instead of the private `_device`.
        encoded_x.to(self.device)
        cor_y = encoded_x['input_ids']
        # Zero out predictions on padding positions.
        cor_y_hat *= encoded_x['attention_mask']

        results = []
        det_acc_labels = []
        cor_acc_labels = []
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            # Skip position 0 (the leading special token added by the padded
            # encoding) and keep exactly len(_src) tokens.
            end = len(_src) + 1
            _tgt = tgt[1:end].tolist()
            _predict = predict[1:end].tolist()
            cor_acc_labels.append(1 if _tgt == _predict else 0)
            det_acc_labels.append(det_predict[1:end].equal(det_label[1:end]))
            results.append((_src, _tgt, _predict,))
        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results

    def validation_epoch_end(self, outputs) -> None:
        """Aggregate per-batch validation outputs, then log loss, accuracies and P/R/F."""
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        self.log('val_loss', loss)
        logger.info(f'loss: {loss}')
        logger.info(f'Detection:\n'
                    f'acc: {np.mean(det_acc_labels):.4f}')
        logger.info(f'Correction:\n'
                    f'acc: {np.mean(cor_acc_labels):.4f}')
        compute_corrector_prf(results, logger)
        compute_sentence_level_prf(results, logger)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        logger.info('Test.')
        self.validation_epoch_end(outputs)

    def predict(self, texts):
        """Correct a batch of sentences and return the decoded strings.

        Args:
            texts: list of input sentences.
        Returns:
            list of corrected sentences, with the spaces the tokenizer's
            ``decode`` inserts removed.
        """
        # Tokenized only to obtain per-sentence lengths. The mask is used as
        # host-side slice bounds, so it does not need to be moved to any
        # device (the previous transfer to cfg.MODEL.DEVICE was wasted work
        # and failed when that device was unavailable).
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt')
        with torch.no_grad():
            outputs = self.forward(texts)
            y_hat = torch.argmax(outputs[1], dim=-1)
        # (real token count - 1) is the exclusive end index that drops the
        # trailing special token; index 0 drops the leading one.
        expand_text_lens = (torch.sum(inputs['attention_mask'], dim=-1) - 1).tolist()
        return [
            self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', '')
            for t_len, _y_hat in zip(expand_text_lens, y_hat)
        ]