纠错任务
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

51 lines
2.1 KiB

2 years ago
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
from abc import ABC
import torch.nn as nn
from transformers import BertForMaskedLM
from pycorrector.macbert.base_model import CscTrainingModel, FocalLoss
class MacBert4Csc(CscTrainingModel, ABC):
    """Spelling-correction model with a masked-LM correction head and a
    token-level detection head.

    ``BertForMaskedLM`` produces per-token correction logits over the
    vocabulary. A linear layer on the last hidden layer produces one
    detection logit per token.
    """

    def __init__(self, cfg, tokenizer):
        """Build the model.

        Args:
            cfg: Config object. ``cfg.MODEL.BERT_CKPT`` is passed to
                ``BertForMaskedLM.from_pretrained``. ``cfg`` is also handed
                to the ``CscTrainingModel`` base class.
            tokenizer: Tokenizer used inside ``forward`` to encode raw text
                batches (called with ``padding=True, return_tensors='pt'``).
        """
        super().__init__(cfg)
        self.cfg = cfg
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        # Detection head: maps each token's final hidden state to one logit.
        self.detection = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.tokenizer = tokenizer

    def forward(self, texts, cor_labels=None, det_labels=None):
        """Encode ``texts`` and run the detection and correction heads.

        Args:
            texts: Batch of input strings, encoded with ``self.tokenizer``.
            cor_labels: Optional batch of corrected target strings. When
                given and non-empty, both losses are computed.
            det_labels: Per-token detection targets, boolean-indexed with the
                input attention mask, so their leading dims must match the
                padded encoding of ``texts``. Presumably a 0/1 tensor of
                shape (batch, seq_len); TODO confirm against the data
                collator. Required whenever ``cor_labels`` is given.

        Returns:
            Without ``cor_labels``: ``(det_logits, cor_logits)``.
            ``det_logits`` is the raw detection-head output (no sigmoid, with
            the trailing dim of size 1 kept). ``cor_logits`` is the MLM
            logits.

            With ``cor_labels``: ``(det_loss, cor_loss, det_probs, cor_logits)``.
            ``det_probs`` is the sigmoid of the detection output with the
            trailing dim squeezed.

        Raises:
            ValueError: If ``cor_labels`` is given but ``det_labels`` is None.
        """
        if cor_labels:
            label_encoding = self.tokenizer(cor_labels, padding=True, return_tensors='pt')
            text_labels = label_encoding['input_ids']
            # -100 is ignored by the MLM loss. Mask padded positions via the
            # attention mask instead of assuming the pad token id is 0.
            text_labels[label_encoding['attention_mask'] == 0] = -100
            text_labels = text_labels.to(self.device)
        else:
            text_labels = None

        encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_text = encoded_text.to(self.device)
        bert_outputs = self.bert(
            **encoded_text,
            labels=text_labels,
            return_dict=True,
            output_hidden_states=True,
        )
        # Detection logits (before sigmoid) from the last hidden layer; the
        # trailing dim has size 1 because the head is Linear(hidden, 1).
        det_out = self.detection(bert_outputs.hidden_states[-1])

        if text_labels is None:
            # Inference: detection output, correction output.
            return det_out, bert_outputs.logits

        if det_labels is None:
            raise ValueError('det_labels is required when cor_labels is given')

        det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
        det_logits = det_out.squeeze(-1)
        # Padded positions do not contribute to the detection loss.
        active_mask = encoded_text['attention_mask'] == 1
        active_logits = det_logits[active_mask]
        # Boolean indexing requires the labels on the same device as the mask.
        active_labels = det_labels.to(active_mask.device)[active_mask]
        det_loss = det_loss_fct(active_logits, active_labels.float())
        # Detection loss, correction loss, detection output, correction output.
        return (det_loss,
                bert_outputs.loss,
                self.sigmoid(det_logits),
                bert_outputs.logits)