# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description: FocalLoss, optimizer/scheduler builders and the Lightning training engines for the MacBERT CSC models
"""
import operator
from abc import ABC

from loguru import logger
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl

from pycorrector.macbert import lr_scheduler
from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf


class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss.
    Copied from https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Args:
            input: model output logits, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            the mean focal loss over the batch (a scalar tensor)
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = torch.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
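        # Multi-label ('sigmoid') case: target is treated as a multi-hot vector,
        # and the focal term penalizes confident mistakes on both the positive
        # and the negative labels.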
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            logits = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()


def make_optimizer(cfg, model):
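    # Build one parameter group per trainable parameter so that bias terms can
    # use their own learning rate and weight decay (see the "bias" branch below).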
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
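        # Bias parameters commonly get a scaled learning rate and a separate
        # (often zero) weight decay; both factors come from the solver config.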
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer


def build_lr_scheduler(cfg, optimizer):
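    # All scheduler options are collected into one kwargs dict; cfg.SOLVER.SCHED
    # names the class in pycorrector.macbert.lr_scheduler that consumes them.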
    scheduler_args = {
        "optimizer": optimizer,

        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,

        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,

        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
    }
    scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
    return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL}


class BaseTrainingEngine(pl.LightningModule):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg

    def configure_optimizers(self):
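        # Lightning hook: return the optimizers and scheduler dicts; the
        # 'interval' entry from build_lr_scheduler controls whether the
        # scheduler steps per epoch or per training step.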
        optimizer = make_optimizer(self.cfg, self)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [scheduler]

    def on_validation_epoch_start(self) -> None:
        logger.info('Validating...')

    def on_test_epoch_start(self) -> None:
        logger.info('Testing...')


class CscTrainingModel(BaseTrainingEngine, ABC):
    """
    Base model for CSC (Chinese Spelling Correction): defines the training and prediction steps.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        # loss weight
        self.w = cfg.MODEL.HYPER_PARAMS[0]

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
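        # Assumption from how the outputs are indexed here and in validation_step:
        # outputs[0]/outputs[1] are the detection/correction losses from the
        # subclass forward, and self.w trades correction loss off against detection.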
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(ori_text))
        return loss

    def validation_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
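        # Token-level predictions: threshold the detection probabilities at 0.5
        # and take the argmax over the vocabulary for the corrected tokens;
        # padded positions are zeroed out via the attention mask.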
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax((outputs[3]), dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        encoded_x.to(self._device)
        cor_y = encoded_x['input_ids']
        cor_y_hat *= encoded_x['attention_mask']

        results = []
        det_acc_labels = []
        cor_acc_labels = []
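        # Compare predictions with targets on the original character positions
        # only: index 0 is [CLS], so positions 1..len(_src) cover the input text.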
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist()
            _predict = predict[1:len(_src) + 1].cpu().numpy().tolist()
            cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0)
            det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1]))
            results.append((_src, _tgt, _predict,))

        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results

    def validation_epoch_end(self, outputs) -> None:
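        # `outputs` is the list of per-batch tuples returned by validation_step
        # (pre-2.0 Lightning API): (loss, det_acc_labels, cor_acc_labels, results).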
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        self.log('val_loss', loss)
        logger.info(f'loss: {loss}')
        logger.info(f'Detection:\n'
                    f'acc: {np.mean(det_acc_labels):.4f}')
        logger.info(f'Correction:\n'
                    f'acc: {np.mean(cor_acc_labels):.4f}')
        compute_corrector_prf(results, logger)
        compute_sentence_level_prf(results, logger)

def test_step(self, batch, batch_idx):
|
|
return self.validation_step(batch, batch_idx)
|
|
|
|
def test_epoch_end(self, outputs) -> None:
|
|
logger.info('Test.')
|
|
self.validation_epoch_end(outputs)
|
|
|
|
def predict(self, texts):
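        # Inference helper: tokenize the raw texts, run the model without labels
        # (outputs[1] is assumed to hold the correction logits in this case),
        # take the per-position argmax, and decode while dropping [CLS], the
        # trailing [SEP]/padding, and the spaces the tokenizer inserts between
        # Chinese characters.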
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt')
        inputs.to(self.cfg.MODEL.DEVICE)
        with torch.no_grad():
            outputs = self.forward(texts)
            y_hat = torch.argmax(outputs[1], dim=-1)
            expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1
        rst = []
        for t_len, _y_hat in zip(expand_text_lens, y_hat):
            rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', ''))
        return rst