"""
|
||
|
@Time : 2021-01-21 12:00:59
|
||
|
@File : modeling_soft_masked_bert.py
|
||
|
@Author : Abtion
|
||
|
@Email : abtion{at}outlook.com
|
||
|
"""
|
||
|
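# Soft-Masked BERT for Chinese Spelling Correction (CSC): a BiGRU detection
# network predicts a per-token error probability, which is used to softly
# interpolate each token embedding with the [MASK] embedding before a
# BERT-based correction network predicts the corrected tokens.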
from abc import ABC
from collections import OrderedDict

import transformers as tfs
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead
from transformers.modeling_utils import ModuleUtilsMixin

from pycorrector.macbert.base_model import CscTrainingModel


class DetectionNetwork(nn.Module):
    """Detection network: a bidirectional GRU that outputs a per-token error probability."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gru = nn.GRU(
            self.config.hidden_size,
            self.config.hidden_size // 2,
            num_layers=2,
            batch_first=True,
            dropout=self.config.hidden_dropout_prob,
            bidirectional=True,
        )
        self.sigmoid = nn.Sigmoid()
        self.linear = nn.Linear(self.config.hidden_size, 1)

    def forward(self, hidden_states):
        out, _ = self.gru(hidden_states)
        prob = self.linear(out)
        prob = self.sigmoid(prob)
        return prob

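
# Illustrative sketch (not part of the original file): a quick shape check for
# DetectionNetwork. It assumes a default `transformers.BertConfig` purely for
# demonstration; the real model uses the config loaded from cfg.MODEL.BERT_CKPT.
def _demo_detector_shapes():
    cfg = tfs.BertConfig()                        # hidden_size=768, hidden_dropout_prob=0.1
    detector = DetectionNetwork(cfg)
    hidden = torch.randn(2, 16, cfg.hidden_size)  # (batch, seq_len, hidden_size) token embeddings
    prob = detector(hidden)                       # (batch, seq_len, 1) per-token error probability
    assert prob.shape == (2, 16, 1)
    return prob
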
class CorrectionNetwork(torch.nn.Module, ModuleUtilsMixin):
    """Correction network: a BERT encoder with an MLM head that operates on soft-masked embeddings."""

    def __init__(self, config, tokenizer, device):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embeddings = BertEmbeddings(self.config)
        self.bert = BertEncoder(self.config)
        self.mask_token_id = self.tokenizer.mask_token_id
        self.cls = BertOnlyMLMHead(self.config)
        self._device = device

    def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
        if cor_labels is not None:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # torch's CrossEntropyLoss ignores positions whose label is -100
            text_labels[text_labels == 0] = -100
            text_labels = text_labels.to(self._device)
        else:
            text_labels = None
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        if embed is None:
            embed = self.embeddings(input_ids=encoded_texts['input_ids'],
                                    token_type_ids=encoded_texts['token_type_ids'])
        # This differs slightly from the original paper: the change keeps the full
        # token_type_ids and position_ids embeddings when embedding the [MASK] token.
        mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # The paper's original implementation:
        # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
        cor_embed = prob * mask_embed + (1 - prob) * embed

        input_shape = encoded_texts['input_ids'].size()
        device = encoded_texts['input_ids'].device

        extended_attention_mask = self.get_extended_attention_mask(encoded_texts['attention_mask'],
                                                                   input_shape, device)
        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
        encoder_outputs = self.bert(
            cor_embed,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            return_dict=False,
        )
        sequence_output = encoder_outputs[0]

        sequence_output = sequence_output + embed if residual_connection else sequence_output
        prediction_scores = self.cls(sequence_output)
        out = (prediction_scores, sequence_output)

        # Masked language modeling softmax layer
        if text_labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
            cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
            out = (cor_loss,) + out
        return out

    def load_from_transformers_state_dict(self, gen_fp):
        """Initialize this network's weights from a pretrained transformers masked-LM checkpoint."""
        state_dict = OrderedDict()
        gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
        for k, v in gen_state_dict.items():
            name = k
            if name.startswith('bert'):
                name = name[5:]
            if name.startswith('encoder'):
                name = f'corrector.{name[8:]}'
            if 'gamma' in name:
                name = name.replace('gamma', 'weight')
            if 'beta' in name:
                name = name.replace('beta', 'bias')
            state_dict[name] = v
        self.load_state_dict(state_dict, strict=False)


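# Illustrative sketch (not part of the original file): the soft-masking step used
# inside CorrectionNetwork.forward, shown in isolation with toy tensors. Shapes and
# values here are made up for demonstration only.
def _demo_soft_masking():
    batch, seq_len, hidden = 1, 4, 8
    embed = torch.randn(batch, seq_len, hidden)          # token embeddings e_i
    mask_embed = torch.randn(batch, seq_len, hidden)     # embedding of the [MASK] token
    prob = torch.tensor([[[0.0], [1.0], [0.5], [0.1]]])  # detector's error probability p_i
    # p_i = 1 -> fully replaced by [MASK]; p_i = 0 -> original embedding kept
    cor_embed = prob * mask_embed + (1 - prob) * embed
    assert cor_embed.shape == (batch, seq_len, hidden)
    return cor_embed

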
class SoftMaskedBert4Csc(CscTrainingModel, ABC):
    """Soft-Masked BERT model for CSC, combining the detection and correction networks."""

    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
        self.detector = DetectionNetwork(self.config)
        self.tokenizer = tokenizer
        self.corrector = CorrectionNetwork(self.config, tokenizer, cfg.MODEL.DEVICE)
        self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
        self._device = cfg.MODEL.DEVICE

    def forward(self, texts, cor_labels=None, det_labels=None):
        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_texts.to(self._device)
        embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
                                          token_type_ids=encoded_texts['token_type_ids'])
        prob = self.detector(embed)
        cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)

        if det_labels is not None:
            det_loss_fct = nn.BCELoss()
            # padded positions are excluded from the detection loss
            active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
        else:
            outputs = (prob.squeeze(-1),) + cor_out

        return outputs

    def load_from_transformers_state_dict(self, gen_fp):
        """
        Load pretrained weights from a transformers checkpoint.
        :param gen_fp: name or path of the pretrained masked-LM checkpoint
        :return:
        """
        self.corrector.load_from_transformers_state_dict(gen_fp)
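

# Illustrative usage sketch (not part of the original file). It assumes `cfg` and
# `tokenizer` are already constructed the way pycorrector's training scripts build
# them (cfg.MODEL.BERT_CKPT and cfg.MODEL.DEVICE must be set); the function name
# and the greedy decoding step are this sketch's own additions.
def _demo_correct(cfg, tokenizer, texts):
    model = SoftMaskedBert4Csc(cfg, tokenizer).to(cfg.MODEL.DEVICE).eval()
    with torch.no_grad():
        # Without labels, forward returns (detection_prob, prediction_scores, sequence_output).
        det_prob, prediction_scores, _ = model(texts)
    corrected_ids = prediction_scores.argmax(dim=-1)  # greedy decode of the corrected tokens
    return tokenizer.batch_decode(corrected_ids, skip_special_tokens=True)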