From c7f4f8ff4fa5f78031b42f2ba103df980d347f34 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Wed, 14 Jun 2023 14:24:30 +0800
Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E6=AC=A1=E6=8F=90=E4=BA=A4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .idea/.gitignore                               |   9 +
 .idea/deployment.xml                           |  28 +++
 .idea/inspectionProfiles/profiles_settings.xml |   6 +
 .idea/macbert.iml                              |  12 ++
 .idea/misc.xml                                 |   4 +
 .idea/modules.xml                              |   8 +
 .idea/vcs.xml                                  |   7 +
 README.md                                      | 237 +++++++++++++++++++++++
 __init__.py                                    |   0
 base_model.py                                  | 191 ++++++++++++++++++
 ceshifenli.py                                  |  94 +++++++++
 correct_demo.py                                |  43 +++++
 defaults.py                                    | 114 +++++++++++
 evaluate_util.py                               | 255 +++++++++++++++++++++++++
 flask_macbert.py                               | 159 +++++++++++++++
 infer.py                                       | 116 +++++++++++
 lr_scheduler.py                                | 178 +++++++++++++++++
 macbert4csc.py                                 |  50 +++++
 macbert_corrector.py                           | 166 ++++++++++++++++
 predict.py                                     |  46 +++++
 preprocess.py                                  | 200 +++++++++++++++++++
 reader.py                                      |  68 +++++++
 rewrite.log                                    | 131 +++++++++++++
 softmaskedbert4csc.py                          | 151 +++++++++++++++
 train.py                                       | 133 +++++++++++++
 train_macbert4csc.yml                          |  24 +++
 train_softmaskedbert4csc.yml                   |  24 +++
 27 files changed, 2454 insertions(+)
 create mode 100644 .idea/.gitignore
 create mode 100644 .idea/deployment.xml
 create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
 create mode 100644 .idea/macbert.iml
 create mode 100644 .idea/misc.xml
 create mode 100644 .idea/modules.xml
 create mode 100644 .idea/vcs.xml
 create mode 100644 README.md
 create mode 100644 __init__.py
 create mode 100644 base_model.py
 create mode 100644 ceshifenli.py
 create mode 100644 correct_demo.py
 create mode 100644 defaults.py
 create mode 100644 evaluate_util.py
 create mode 100644 flask_macbert.py
 create mode 100644 infer.py
 create mode 100644 lr_scheduler.py
 create mode 100644 macbert4csc.py
 create mode 100644 macbert_corrector.py
 create mode 100644 predict.py
 create mode 100644 preprocess.py
 create mode 100644 reader.py
 create mode 100644 rewrite.log
 create mode 100644 softmaskedbert4csc.py
 create mode 100644 train.py
 create mode 100644 train_macbert4csc.yml
 create mode 100644 train_softmaskedbert4csc.yml

diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..2b1ec50
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,9 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+/macbert4csc-base-chinese
\ No newline at end of file

diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..29bd65a
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,28 @@
(XML content stripped during extraction)

diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
(XML content stripped during extraction)

diff --git a/.idea/macbert.iml b/.idea/macbert.iml
new file mode 100644
index 0000000..7e114df
--- /dev/null
+++ b/.idea/macbert.iml
@@ -0,0 +1,12 @@
(XML content stripped during extraction)

diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..d5d8cad
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
(XML content stripped during extraction)

diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..a3d9d16
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
(XML content stripped during extraction)

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..ef99ac1
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,7 @@
(XML content stripped during extraction)

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7b6a66f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,237 @@
+# MacBertMaskedLM For Correction
+This project is a Chinese text-correction model built by modifying the MacBERT network structure; any BERT-style model can serve as the backbone.
+> "MacBERT shares the same pre-training tasks as BERT with several modifications." —— (Cui et al., Findings of EMNLP 2020)
+
+The MacBert4csc network structure:
++ On top of the usual BERT model, an extra fully connected layer is added as the error-detection head, i.e. [detection](https://github.com/shibing624/pycorrector/blob/c0f31222b7849c452cc1ec207c71e9954bd6ca08/pycorrector/macbert/macbert4csc.py#L18).
++ The difference from SoftMaskedBERT is that the MacBERT here simply takes a weighted sum of the detection-layer and correction-layer losses as the final loss; it does not feed the detection layer's confidence back in as input weights for correction, the way SoftMaskedBERT does.
+
+![macbert_network](https://github.com/shibing624/pycorrector/blob/master/docs/git_image/macbert_network.jpg)
+
+#### About MacBERT
+MacBERT stands for MLM as correction BERT, where MLM means masked language model.
+MacBERT can use any BERT-style network; its defining feature is a different MLM task design during pre-training:
++ Whole-word masking (wwm) and N-gram masking strategies are used to select the candidate tokens to mask;
++ Where BERT-style models mask the original word with `[MASK]`, MacBERT uses a third-party synonym toolkit to generate a similar word for the target and masks with that; when the original word has no similar word, a random n-gram is used instead;
++ As in BERT-style models, for each training sample 80% of the selected input words are replaced with similar words (instead of `[MASK]`), 10% with random words, and 10% are left unchanged.
+
+MLM as Correction Mask strategies:
+![macbert_strategies](https://github.com/shibing624/pycorrector/blob/master/docs/git_image/macbert_mask_strategies.jpg)
+
+
+## Usage
+
+### Quick start
+
+
+#### Using pycorrector
+
+example: [correct_demo.py](correct_demo.py)
+
+```python
+from pycorrector.macbert.macbert_corrector import MacBertCorrector
+
+nlp = MacBertCorrector("shibing624/macbert4csc-base-chinese").macbert_correct
+
+i = nlp('今天新情很好')
+print(i)
+```
+#### Using transformers
+You can also call the model directly through the official transformers library.
+
+1. Install transformers via pip:
+
+```shell
+pip install transformers>=4.1.1
+```
+2. Run the following example:
+
+```python
+import operator
+import torch
+from transformers import BertTokenizerFast, BertForMaskedLM
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+tokenizer = BertTokenizerFast.from_pretrained("shibing624/macbert4csc-base-chinese")
+model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
+model.to(device)
+
+texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
+with torch.no_grad():
+    outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
+
+def get_errors(corrected_text, origin_text):
+    sub_details = []
+    for i, ori_char in enumerate(origin_text):
+        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
+            # add unk word
+            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
+            continue
+        if i >= len(corrected_text):
+            continue
+        if ori_char != corrected_text[i]:
+            if ori_char.lower() == corrected_text[i]:
+                # pass english upper char
+                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
+                continue
+            sub_details.append((ori_char, corrected_text[i], i, i + 1))
+    sub_details = sorted(sub_details, key=operator.itemgetter(2))
+    return corrected_text, sub_details
+
+result = []
+for ids, text in zip(outputs.logits, texts):
+    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
+    corrected_text = _text[:len(text)]
+    corrected_text, details = get_errors(corrected_text, text)
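+    # get_errors() above re-inserts characters the tokenizer drops or maps to
+    # [UNK] (spaces, rare chars) so the output stays aligned with the input text.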
+    print(text, ' => ', corrected_text, details)
+    result.append((corrected_text, details))
+print(result)
+```
+
+output:
+```shell
+今天新情很好 => 今天心情很好 [('新', '心', 2, 3)]
+你找到你最喜欢的工作,我也很高心。 => 你找到你最喜欢的工作,我也很高兴。 [('心', '兴', 15, 16)]
+```
+
+Model files:
+```
+macbert4csc-base-chinese
+    ├── config.json
+    ├── added_tokens.json
+    ├── pytorch_model.bin
+    ├── special_tokens_map.json
+    ├── tokenizer_config.json
+    └── vocab.txt
+```
+
+## Evaluate
+
+An evaluation script is provided at [pycorrector/utils/eval.py](../utils/eval.py). It does two things:
+
+- Builds the evaluation sample set: the eval corpus [pycorrector/data/eval_corpus.json](../data/eval_corpus.json)
+  contains 100 character-level errors, 100 word-level errors, 100 grammatical errors, and 200 correct sentences. You can change the counts to generate other sample distributions.
+- Computes correction precision and recall: a conservative scheme that counts a correction as right only when the corrected sentence exactly matches the reference sentence, and as wrong otherwise.
+
+Running the script gives the following results.
+
+`shibing624/macbert4csc-base-chinese` on corpus500:
+
+- Sentence Level: acc:0.6560, precision:0.7797, recall:0.5919, f1:0.6730
+
+The rule-based method (with a custom confusion set) on corpus500:
+
+- Sentence Level: acc:0.6400, recall:0.5067
+
+`shibing624/macbert4csc-base-chinese` on the SIGHAN2015 test set:
+
+- Char Level: precision:0.9372, recall:0.8640, f1:0.8991
+- Sentence Level: precision:0.8264, recall:0.7366, f1:0.7789
+
+Because the training data included the SIGHAN2015 training set (to reproduce the paper), the model reaches SOTA on the SIGHAN2015 test set.
+
+#### Evaluation cases
+
+- run `python tests/macbert_corrector_test.py`
+  ![result](../../docs/git_image/macbert_result.jpg)
+It reaches SOTA level on the SIGHAN2015 test set.
+
+
+## Training
+
+### Install dependencies
+```shell
+pip install transformers>=4.1.1 pytorch-lightning==1.4.9 torch>=1.7.0 yacs
+```
+### Training datasets
+
+#### Toy dataset (about 1k samples)
+```shell
+cd macbert
+python preprocess.py
+```
+This produces the toy dataset files:
+```shell
+macbert/output
+|-- dev.json
+|-- test.json
+`-- train.json
+```
+#### SIGHAN+Wang271K Chinese correction dataset
+
+
+| Dataset | Corpus | Download | Archive size |
+| :------- | :--------- | :---------: | :---------: |
+| **`SIGHAN+Wang271K中文纠错数据集`** | SIGHAN+Wang271K (270k samples) | [Baidu Netdisk (code 01b9)](https://pan.baidu.com/s/1BV5tr9eONZCI0wERFvr0gQ) / [shibing624/CSC](https://huggingface.co/datasets/shibing624/CSC) | 106M |
+| **`原始SIGHAN数据集`** | SIGHAN13 14 15 | [official csc.html](http://nlp.ee.ncu.edu.tw/resource/csc.html) | 339K |
+| **`原始Wang271K数据集`** | Wang271K | [Automatic-Corpus-Generation, provided by dimmywang](https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml) | 93M |
+
+
+The SIGHAN+Wang271K dataset uses this format:
+```json
+[
+    {
+        "id": "B2-4029-3",
+        "original_text": "晚间会听到嗓音,白天的时候大家都不会太在意,但是在睡觉的时候这嗓音成为大家的恶梦。",
+        "wrong_ids": [
+            5,
+            31
+        ],
+        "correct_text": "晚间会听到噪音,白天的时候大家都不会太在意,但是在睡觉的时候这噪音成为大家的恶梦。"
+    }
+]
+```
+
+Download `SIGHAN+Wang271K中文纠错数据集`, create an output folder, and put the files inside it, in the layout shown above.
+
+#### Your own dataset
+
+Annotate your data, save it in the same JSON format as the training set, then load the model and continue training.
+
+1. If you already have plenty of domain-specific error samples, annotate mainly the error positions (wrong_ids) and the corrected sentence (correct_text).
+2. If you have no ready-made error samples, you can generate them (original_text) with a script: change characters at chosen positions (wrong_ids) of correct sentences into wrong characters that are similar in sound or shape. A third-party homophone-replacement script is available: [同音词替换](https://github.com/dongrixinyu/JioNLP/wiki/%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-%E8%AF%B4%E6%98%8E%E6%96%87%E6%A1%A3#%E5%90%8C%E9%9F%B3%E8%AF%8D%E6%9B%BF%E6%8D%A2)
+
+### Train the MacBert4CSC model
+```shell
+python train.py
+```
+
+Note: MacBert4CSC only handles correction of aligned text; it cannot handle extra or missing characters, so original_text and correct_text must have the same length in the training set.
+Otherwise training fails with: "ValueError: Expected input batch_size (*A) to match target batch_size (*B)."
+
+### Inference
+- Option 1: load the saved ckpt file directly:
+```shell
+python infer.py
+```
+
+- Option 2: load the `pytorch_model.bin` file:
+copy the model files below from `output/macbert4csc` to `~/.pycorrector/datasets/macbert_models/chinese_finetuned_correction`,
+then call the model through pycorrector or transformers as in Quick start above.
+
+```shell
+output
+└── macbert4csc
+    ├── config.json
+    ├── pytorch_model.bin
+    ├── special_tokens_map.json
+    ├── tokenizer_config.json
+    └── vocab.txt
+```
+
+Demo [macbert_corrector.py](macbert_corrector.py):
+```
+python3 macbert_corrector.py
+```
+
+### Train the SoftMaskedBert4CSC model
+```shell
+python train.py --config_file train_softmaskedbert4csc.yml
+```
+
+# Reference
+- [BertBasedCorrectionModels](https://github.com/gitabtion/BertBasedCorrectionModels)
+- 
Cui, Y., Che, W., Liu, T., Qin, B., Wang, S., & Hu, G. (2020). Revisiting Pre-Trained Models for Chinese Natural Language Processing. Findings of the Association for Computational Linguistics: EMNLP 2020, 657–668. https://doi.org/10.18653/v1/2020.findings-emnlp.58
(The publication for [MacBERT](https://arxiv.org/pdf/2004.13922.pdf)) + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/base_model.py b/base_model.py new file mode 100644 index 0000000..7ce80ca --- /dev/null +++ b/base_model.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com) +@description: +""" +import operator +from abc import ABC +from loguru import logger +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from pycorrector.macbert import lr_scheduler +from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf + + +class FocalLoss(nn.Module): + """ + Softmax and sigmoid focal loss. + copy from https://github.com/lonePatient/TorchBlocks + """ + + def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9): + + super(FocalLoss, self).__init__() + self.num_labels = num_labels + self.gamma = gamma + self.alpha = alpha + self.epsilon = epsilon + self.activation_type = activation_type + + def forward(self, input, target): + """ + Args: + logits: model's output, shape of [batch_size, num_cls] + target: ground truth labels, shape of [batch_size] + Returns: + shape of [batch_size] + """ + if self.activation_type == 'softmax': + idx = target.view(-1, 1).long() + one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device) + one_hot_key = one_hot_key.scatter_(1, idx, 1) + logits = torch.softmax(input, dim=-1) + loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() + loss = loss.sum(1) + elif self.activation_type == 'sigmoid': + multi_hot_key = target + logits = torch.sigmoid(input) + zero_hot_key = 1 - multi_hot_key + loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() + loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log() + return loss.mean() + + +def make_optimizer(cfg, model): + params = [] + for key, value in model.named_parameters(): + if not value.requires_grad: + continue + lr = cfg.SOLVER.BASE_LR + weight_decay = cfg.SOLVER.WEIGHT_DECAY + if "bias" in key: + lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR + weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS + params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] + if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': + optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) + else: + optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) + return optimizer + + +def build_lr_scheduler(cfg, optimizer): + scheduler_args = { + "optimizer": optimizer, + + # warmup options + "warmup_factor": cfg.SOLVER.WARMUP_FACTOR, + "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS, + "warmup_method": cfg.SOLVER.WARMUP_METHOD, + + # multi-step lr scheduler options + "milestones": cfg.SOLVER.STEPS, + "gamma": cfg.SOLVER.GAMMA, + + # cosine annealing lr scheduler options + "max_iters": cfg.SOLVER.MAX_ITER, + "delay_iters": cfg.SOLVER.DELAY_ITERS, + "eta_min_lr": cfg.SOLVER.ETA_MIN_LR, + + } + scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args) + return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL} + + +class BaseTrainingEngine(pl.LightningModule): + def __init__(self, cfg, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg = cfg + + def configure_optimizers(self): + 
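        # Build one parameter group per tensor from cfg.SOLVER (biases get their own
+        # LR factor and weight decay), then attach the LR scheduler the config selects.
+        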
optimizer = make_optimizer(self.cfg, self) + scheduler = build_lr_scheduler(self.cfg, optimizer) + + return [optimizer], [scheduler] + + def on_validation_epoch_start(self) -> None: + logger.info('Valid.') + + def on_test_epoch_start(self) -> None: + logger.info('Testing...') + + +class CscTrainingModel(BaseTrainingEngine, ABC): + """ + 用于CSC的BaseModel, 定义了训练及预测步骤 + """ + + def __init__(self, cfg, *args, **kwargs): + super().__init__(cfg, *args, **kwargs) + # loss weight + self.w = cfg.MODEL.HYPER_PARAMS[0] + + def training_step(self, batch, batch_idx): + ori_text, cor_text, det_labels = batch + outputs = self.forward(ori_text, cor_text, det_labels) + loss = self.w * outputs[1] + (1 - self.w) * outputs[0] + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(ori_text)) + return loss + + def validation_step(self, batch, batch_idx): + ori_text, cor_text, det_labels = batch + outputs = self.forward(ori_text, cor_text, det_labels) + loss = self.w * outputs[1] + (1 - self.w) * outputs[0] + det_y_hat = (outputs[2] > 0.5).long() + cor_y_hat = torch.argmax((outputs[3]), dim=-1) + encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt') + encoded_x.to(self._device) + cor_y = encoded_x['input_ids'] + cor_y_hat *= encoded_x['attention_mask'] + + results = [] + det_acc_labels = [] + cor_acc_labels = [] + for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels): + _src = self.tokenizer(src, add_special_tokens=False)['input_ids'] + _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist() + _predict = predict[1:len(_src) + 1].cpu().numpy().tolist() + cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0) + det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1])) + results.append((_src, _tgt, _predict,)) + + return loss.cpu().item(), det_acc_labels, cor_acc_labels, results + + def validation_epoch_end(self, outputs) -> None: + det_acc_labels = [] + cor_acc_labels = [] + results = [] + for out in outputs: + det_acc_labels += out[1] + cor_acc_labels += out[2] + results += out[3] + loss = np.mean([out[0] for out in outputs]) + self.log('val_loss', loss) + logger.info(f'loss: {loss}') + logger.info(f'Detection:\n' + f'acc: {np.mean(det_acc_labels):.4f}') + logger.info(f'Correction:\n' + f'acc: {np.mean(cor_acc_labels):.4f}') + compute_corrector_prf(results, logger) + compute_sentence_level_prf(results, logger) + + def test_step(self, batch, batch_idx): + return self.validation_step(batch, batch_idx) + + def test_epoch_end(self, outputs) -> None: + logger.info('Test.') + self.validation_epoch_end(outputs) + + def predict(self, texts): + inputs = self.tokenizer(texts, padding=True, return_tensors='pt') + inputs.to(self.cfg.MODEL.DEVICE) + with torch.no_grad(): + outputs = self.forward(texts) + y_hat = torch.argmax(outputs[1], dim=-1) + expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1 + rst = [] + for t_len, _y_hat in zip(expand_text_lens, y_hat): + rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', '')) + return rst diff --git a/ceshifenli.py b/ceshifenli.py new file mode 100644 index 0000000..f4635c5 --- /dev/null +++ b/ceshifenli.py @@ -0,0 +1,94 @@ +import unicodedata +def is_chinese(char): + if 'CJK' in unicodedata.name(char): + return True + else: + return False + + + +a = "ab1我们12是一个" + +b = [""] *len(a) + +last_post = False + +c = [] +for i, d in enumerate(a): + bool_ = is_chinese(d) + if bool_ == False: + b[i] = d + last_post = False 
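+    # non-Chinese chars are pinned into b at their original positions; runs of
+    # consecutive Chinese chars (handled below) are grouped into c so each run
+    # can be processed as one unit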
+ else: + if last_post == False: + c.append([(i,d)]) + else: + c[-1].append((i,d)) + last_post = True +print(c) +print(b) + +d = [] +for i in c: + d.append("".join([j[1] for j in i])) +print(d) + +e = d + +f = "" +for i in e: + f += i +f_list = list(f) +print(f_list) + +for i,d in enumerate(b): + if d == "": + zi = f_list.pop(0) + print(zi) + b[i] = zi +print(b) + +class SentenceUlit: + def __init__(self,sentence): + self.sentence = sentence + self.sentence_list = [""] * len(sentence) + self.last_post = False + self.sentence_batch = [] + self.pre_ulit() + self.inf_sentence_batch_str = "" + + + def is_chinese(self, char): + if 'CJK' in unicodedata.name(char): + return True + else: + return False + + def pre_ulit(self): + for i, d in enumerate(self.sentence): + bool_ = is_chinese(d) + if bool_ == False: + self.sentence_list[i] = d + self.last_post = False + else: + if self.last_post == False: + self.sentence_batch.append(d) + else: + self.sentence_batch[-1] += d + self.last_post = True + + def inf_ulit(self, sen): + for i in sen: + self.inf_sentence_batch_str += i + self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str) + + for i, d in enumerate(self.sentence_list): + if d == "": + zi = self.inf_sentence_batch_srt_list.pop(0) + self.sentence_list[i] = zi + + +sen = SentenceUlit("ab1我们12是一个") + +print(sen.sentence_batch) +print(sen.sentence_list) \ No newline at end of file diff --git a/correct_demo.py b/correct_demo.py new file mode 100644 index 0000000..16bf6b5 --- /dev/null +++ b/correct_demo.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2021-02-03 21:57:15 +@File : correct_demo.py +@Author : Abtion +@Email : abtion{at}outlook.com +""" +import argparse +import sys + +from pycorrector.macbert.macbert_corrector import MacBertCorrector +from pycorrector import config + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--macbert_model_dir", default=config.macbert_model_dir, + type=str, + help="MacBert pre-trained model dir") + args = parser.parse_args() + + nlp = MacBertCorrector(args.macbert_model_dir).macbert_correct + + i = nlp('今天新情很好') + print(i) + + i = nlp('少先队员英该为老人让座') + print(i) + + i = nlp('机器学习是人工智能领遇最能体现智能的一个分知。') + print(i) + + i = nlp('机其学习是人工智能领遇最能体现智能的一个分知。') + print(i) + + print(nlp('老是较书。')) + print(nlp('遇到一位很棒的奴生跟我聊天。')) + + +if __name__ == "__main__": + main() diff --git a/defaults.py b/defaults.py new file mode 100644 index 0000000..5783a63 --- /dev/null +++ b/defaults.py @@ -0,0 +1,114 @@ +""" +@Time : 2021-01-21 10:37:36 +@File : defaults.py +@Author : Abtion +@Email : abtion{at}outlook.com +""" +from yacs.config import CfgNode as CN + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. 
+# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +_C.MODEL = CN() +_C.MODEL.DEVICE = "cpu" +_C.MODEL.GPU_IDS = [0] +_C.MODEL.NUM_CLASSES = 10 +_C.MODEL.BERT_CKPT = 'bert-base-chinese' +_C.MODEL.NAME = '' +_C.MODEL.WEIGHTS = '' +_C.MODEL.HYPER_PARAMS = [] + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# Max length of input text. +_C.INPUT.MAX_LEN = 512 + + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training, as present in paths_catalog.py +_C.DATASETS.TRAIN = "" +# List of the dataset names for validation, as present in paths_catalog.py +_C.DATASETS.VALID = "" +# List of the dataset names for testing, as present in paths_catalog.py +_C.DATASETS.TEST = "" + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() +_C.SOLVER.OPTIMIZER_NAME = "AdamW" + +_C.SOLVER.MAX_EPOCHS = 50 + +_C.SOLVER.BASE_LR = 0.001 +_C.SOLVER.BIAS_LR_FACTOR = 2 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.WEIGHT_DECAY = 0.0005 +_C.SOLVER.WEIGHT_DECAY_BIAS = 0 + +_C.SOLVER.GAMMA = 0.9999 +_C.SOLVER.STEPS = (10,) +_C.SOLVER.SCHED = "WarmupExponentialLR" +_C.SOLVER.WARMUP_FACTOR = 0.01 +_C.SOLVER.WARMUP_ITERS = 2 +_C.SOLVER.WARMUP_EPOCHS = 1024 +_C.SOLVER.WARMUP_METHOD = "linear" +_C.SOLVER.DELAY_ITERS = 0 +_C.SOLVER.ETA_MIN_LR = 3e-7 +_C.SOLVER.MAX_ITER = 10 +_C.SOLVER.INTERVAL = 'step' + +_C.SOLVER.CHECKPOINT_PERIOD = 10 +_C.SOLVER.LOG_PERIOD = 100 +_C.SOLVER.ACCUMULATE_GRAD_BATCHES = 1 +# Number of images per batch +# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will +# see 2 images per batch +_C.SOLVER.BATCH_SIZE = 16 + + +_C.TEST = CN() +_C.TEST.BATCH_SIZE = 8 +_C.TEST.CKPT_FN = "" + +# ---------------------------------------------------------------------------- # +# Task specific +# ---------------------------------------------------------------------------- # +_C.TASK = CN() +_C.TASK.NAME = "CSC" + + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +_C.OUTPUT_DIR = "" +_C.MODE = ['train', 'test'] + + diff --git a/evaluate_util.py b/evaluate_util.py new file mode 100644 index 0000000..498acd3 --- /dev/null +++ b/evaluate_util.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com) +@description: +""" + + +def compute_corrector_prf(results, logger): + """ + copy from 
https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check/blob/master/utils/evaluation_metrics.py + """ + TP = 0 + FP = 0 + FN = 0 + all_predict_true_index = [] + all_gold_index = [] + for item in results: + src, tgt, predict = item + gold_index = [] + each_true_index = [] + for i in range(len(list(src))): + if src[i] == tgt[i]: + continue + else: + gold_index.append(i) + all_gold_index.append(gold_index) + predict_index = [] + for i in range(len(list(src))): + if src[i] == predict[i]: + continue + else: + predict_index.append(i) + + for i in predict_index: + if i in gold_index: + TP += 1 + each_true_index.append(i) + else: + FP += 1 + for i in gold_index: + if i in predict_index: + continue + else: + FN += 1 + all_predict_true_index.append(each_true_index) + + # For the detection Precision, Recall and F1 + detection_precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + detection_recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + if detection_precision + detection_recall == 0: + detection_f1 = 0 + else: + detection_f1 = 2 * (detection_precision * detection_recall) / (detection_precision + detection_recall) + logger.info( + "The detection result is precision={}, recall={} and F1={}".format(detection_precision, detection_recall, + detection_f1)) + + TP = 0 + FP = 0 + FN = 0 + + for i in range(len(all_predict_true_index)): + # we only detect those correctly detected location, which is a different from the common metrics since + # we wanna to see the precision improve by using the confusionset + if len(all_predict_true_index[i]) > 0: + predict_words = [] + for j in all_predict_true_index[i]: + predict_words.append(results[i][2][j]) + if results[i][1][j] == results[i][2][j]: + TP += 1 + else: + FP += 1 + for j in all_gold_index[i]: + if results[i][1][j] in predict_words: + continue + else: + FN += 1 + + # For the correction Precision, Recall and F1 + correction_precision = TP / (TP + FP) if (TP + FP) > 0 else 0 + correction_recall = TP / (TP + FN) if (TP + FN) > 0 else 0 + if correction_precision + correction_recall == 0: + correction_f1 = 0 + else: + correction_f1 = 2 * (correction_precision * correction_recall) / (correction_precision + correction_recall) + logger.info("The correction result is precision={}, recall={} and F1={}".format(correction_precision, + correction_recall, + correction_f1)) + + return detection_f1, correction_f1 + + +def compute_sentence_level_prf(results, logger): + """ + 自定义的句级prf,设定需要纠错为正样本,无需纠错为负样本 + :param results: + :return: + """ + + TP = 0.0 + FP = 0.0 + FN = 0.0 + TN = 0.0 + total_num = len(results) + + for item in results: + src, tgt, predict = item + + # 负样本 + if src == tgt: + # 预测也为负 + if tgt == predict: + TN += 1 + # 预测为正 + else: + FP += 1 + # 正样本 + else: + # 预测也为正 + if tgt == predict: + TP += 1 + # 预测为负 + else: + FN += 1 + + acc = (TP + TN) / total_num + precision = TP / (TP + FP) if TP > 0 else 0.0 + recall = TP / (TP + FN) if TP > 0 else 0.0 + f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0 + + logger.info(f'Sentence Level: acc:{acc:.6f}, precision:{precision:.6f}, recall:{recall:.6f}, f1:{f1:.6f}') + return acc, precision, recall, f1 + + +def report_prf(tp, fp, fn, phase, logger=None, return_dict=False): + # For the detection Precision, Recall and F1 + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + if precision + recall == 0: + f1_score = 0 + else: + f1_score = 2 * (precision * recall) / (precision + recall) + + if phase and 
logger: + logger.info(f"The {phase} result is: " + f"{precision:.4f}/{recall:.4f}/{f1_score:.4f} -->\n" + # f"precision={precision:.6f}, recall={recall:.6f} and F1={f1_score:.6f}\n" + f"support: TP={tp}, FP={fp}, FN={fn}") + if return_dict: + ret_dict = { + f'{phase}_p': precision, + f'{phase}_r': recall, + f'{phase}_f1': f1_score} + return ret_dict + return precision, recall, f1_score + + +def compute_corrector_prf_faspell(results, logger=None, strict=True): + """ + All-in-one measure function. + based on FASpell's measure script. + :param results: a list of (wrong, correct, predict, ...) + both token_ids or characters are fine for the script. + :param logger: take which logger to print logs. + :param strict: a more strict evaluation mode (all-char-detected/corrected) + References: + sentence-level PRF: https://github.com/iqiyi/ + FASPell/blob/master/faspell.py + """ + + corrected_char, wrong_char = 0, 0 + corrected_sent, wrong_sent = 0, 0 + true_corrected_char = 0 + true_corrected_sent = 0 + true_detected_char = 0 + true_detected_sent = 0 + accurate_detected_sent = 0 + accurate_corrected_sent = 0 + all_sent = 0 + + for item in results: + # wrong, correct, predict, d_tgt, d_predict = item + wrong, correct, predict = item[:3] + + all_sent += 1 + wrong_num = 0 + corrected_num = 0 + original_wrong_num = 0 + true_detected_char_in_sentence = 0 + + for c, w, p in zip(correct, wrong, predict): + if c != p: + wrong_num += 1 + if w != p: + corrected_num += 1 + if c == p: + true_corrected_char += 1 + if w != c: + true_detected_char += 1 + true_detected_char_in_sentence += 1 + if c != w: + original_wrong_num += 1 + + corrected_char += corrected_num + wrong_char += original_wrong_num + if original_wrong_num != 0: + wrong_sent += 1 + if corrected_num != 0 and wrong_num == 0: + true_corrected_sent += 1 + + if corrected_num != 0: + corrected_sent += 1 + + if strict: # find out all faulty wordings' potisions + true_detected_flag = (true_detected_char_in_sentence == original_wrong_num \ + and original_wrong_num != 0 \ + and corrected_num == true_detected_char_in_sentence) + else: # think it has faulty wordings + true_detected_flag = (corrected_num != 0 and original_wrong_num != 0) + + # if corrected_num != 0 and original_wrong_num != 0: + if true_detected_flag: + true_detected_sent += 1 + if correct == predict: + accurate_corrected_sent += 1 + if correct == predict or true_detected_flag: + accurate_detected_sent += 1 + + counts = { # TP, FP, TN for each level + 'det_char_counts': [true_detected_char, + corrected_char - true_detected_char, + wrong_char - true_detected_char], + 'cor_char_counts': [true_corrected_char, + corrected_char - true_corrected_char, + wrong_char - true_corrected_char], + 'det_sent_counts': [true_detected_sent, + corrected_sent - true_detected_sent, + wrong_sent - true_detected_sent], + 'cor_sent_counts': [true_corrected_sent, + corrected_sent - true_corrected_sent, + wrong_sent - true_corrected_sent], + 'det_sent_acc': accurate_detected_sent / all_sent, + 'cor_sent_acc': accurate_corrected_sent / all_sent, + 'all_sent_count': all_sent, + } + + details = {} + for phase in ['det_char', 'cor_char', 'det_sent', 'cor_sent']: + dic = report_prf( + *counts[f'{phase}_counts'], + phase=phase, logger=logger, + return_dict=True) + details.update(dic) + details.update(counts) + return details diff --git a/flask_macbert.py b/flask_macbert.py new file mode 100644 index 0000000..818afb4 --- /dev/null +++ b/flask_macbert.py @@ -0,0 +1,159 @@ +import os +from flask import Flask, jsonify +from flask 
import request +import operator +import torch +from transformers import BertTokenizerFast, BertForMaskedLM +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +import uuid +import json +from threading import Thread +import time +import re +import logging +import unicodedata + + +logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) +db_key_query = 'query' +batch_size = 32 +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False +import logging + +pattern = r"[。]" +RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") +fuhao_end_sentence = ["。", ",", "?", "!", "…"] + +tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese") +model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese") +model.to(device) + + +def is_chinese(char): + if 'CJK' in unicodedata.name(char): + return True + else: + return False + + +class SentenceUlit: + def __init__(self, sentence): + self.sentence = sentence + self.sentence_list = [""] * len(sentence) + self.last_post = False + self.sentence_batch = [] + self.pre_ulit() + self.inf_sentence_batch_str = "" + + def is_chinese(self, char): + if 'CJK' in unicodedata.name(char): + return True + else: + return False + + def pre_ulit(self): + for i, d in enumerate(self.sentence): + bool_ = is_chinese(d) + if bool_ == False: + self.sentence_list[i] = d + self.last_post = False + else: + if self.last_post == False: + self.sentence_batch.append(d) + else: + self.sentence_batch[-1] += d + self.last_post = True + + def inf_ulit(self, sen): + for i in sen: + self.inf_sentence_batch_str += i + self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str) + + for i, d in enumerate(self.sentence_list): + if d == "": + zi = self.inf_sentence_batch_srt_list.pop(0) + self.sentence_list[i] = zi + + +class log: + def __init__(self): + pass + + def log(*args, **kwargs): + format = '%Y/%m/%d-%H:%M:%S' + format_h = '%Y-%m-%d' + value = time.localtime(int(time.time())) + dt = time.strftime(format, value) + dt_log_file = time.strftime(format_h, value) + log_file = 'log_file/access-%s' % dt_log_file + ".log" + if not os.path.exists(log_file): + with open(os.path.join(log_file), 'w', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + else: + with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + +def get_errors(corrected_text, origin_text): + sub_details = [] + for i, ori_char in enumerate(origin_text): + if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: + # add unk word + corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] + continue + if i >= len(corrected_text): + continue + if ori_char != corrected_text[i]: + if ori_char.lower() == corrected_text[i]: + # pass english upper char + corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] + continue + sub_details.append((ori_char, corrected_text[i], i, i + 1)) + sub_details = sorted(sub_details, key=operator.itemgetter(2)) + return corrected_text, sub_details + + +def main(texts): + with torch.no_grad(): + outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device)) + + result = [] + print(outputs.logits) + for ids, text in zip(outputs.logits, texts): + + _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '') + corrected_text = 
_text[:len(text)]
+        print(corrected_text)
+        corrected_text, details = get_errors(corrected_text, text)
+        result.append({"old": text,
+                       "new": corrected_text,
+                       "re_pos": details})
+    return result
+
+
+@app.route("/predict", methods=["POST"])
+def handle_query():
+    print(request.remote_addr)
+    texts = request.json["texts"]
+    return_list = main(texts)
+    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
+    return jsonify(return_text)  # return the response
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG,  # console log level
+                        filename='rewrite.log',
+                        filemode='a',  # 'w' rewrites the log on every run, overwriting earlier entries;
+                        # 'a' appends, and is the default when omitted
+                        format=
+                        '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
+                        # log format
+                        )
+    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..d9a8d02
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com), okcd00(okcd00@qq.com)
+@description:
+"""
+import sys
+import torch
+import argparse
+from transformers import BertTokenizerFast
+from loguru import logger
+sys.path.append('../..')
+
+from pycorrector.macbert.macbert4csc import MacBert4Csc
+from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
+from pycorrector.macbert.macbert_corrector import get_errors
+from pycorrector.macbert.defaults import _C as cfg
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Inference:
+    def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
+                 vocab_path='output/macbert4csc/vocab.txt',
+                 cfg_path='train_macbert4csc.yml'):
+        logger.debug("device: {}".format(device))
+        self.tokenizer = BertTokenizerFast.from_pretrained(vocab_path)
+        cfg.merge_from_file(cfg_path)
+
+        if 'macbert4csc' in cfg_path:
+            self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
+                                                          cfg=cfg,
+                                                          map_location=device,
+                                                          tokenizer=self.tokenizer)
+        elif 'softmaskedbert4csc' in cfg_path:
+            self.model = SoftMaskedBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
+                                                                 cfg=cfg,
+                                                                 map_location=device,
+                                                                 tokenizer=self.tokenizer)
+        else:
+            raise ValueError("model not found.")
+        self.model.to(device)
+        self.model.eval()
+
+    def predict(self, sentence_list):
+        """
+        Run the correction model.
+        Args:
+            sentence_list: list
+                input texts
+        Returns: tuple
+            corrected_texts(list)
+        """
+        is_str = False
+        if isinstance(sentence_list, str):
+            is_str = True
+            sentence_list = [sentence_list]
+        corrected_texts = self.model.predict(sentence_list)
+        if is_str:
+            return corrected_texts[0]
+        return corrected_texts
+
+    def predict_with_error_detail(self, sentence_list):
+        """
+        Run the correction model and also return error-position details.
+        Args:
+            sentence_list: list
+                input texts
+        Returns: tuple
+            corrected_texts(list), details(list)
+        """
+        details = []
+        is_str = False
+        if isinstance(sentence_list, str):
+            is_str = True
+            sentence_list = [sentence_list]
+        corrected_texts = self.model.predict(sentence_list)
+
+        for corrected_text, text in zip(corrected_texts, sentence_list):
+            corrected_text, sub_details = get_errors(corrected_text, text)
+            details.append(sub_details)
+        if is_str:
+            return corrected_texts[0], details[0]
+        return corrected_texts, details
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="infer")
+    parser.add_argument("--ckpt_path", default="output/macbert4csc/epoch=09-val_loss=0.01.ckpt",
+                        help="path to checkpoint file", type=str)
+    parser.add_argument("--vocab_path", default="output/macbert4csc/vocab.txt", help="path to vocab file", type=str)
+    parser.add_argument("--config_file", default="train_macbert4csc.yml", help="path to config file", type=str)
+    args = parser.parse_args()
+    m = Inference(args.ckpt_path, args.vocab_path, args.config_file)
+    inputs = [
+        '它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。',
+        '老是较书。',
+        '少先队 员因该 为老人让 坐',
+        '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
+        '遇到一位很棒的奴生跟我聊天。',
+        '遇到一位很美的女生跟我疗天。',
+        '他们只能有两个选择:接受降新或自动离职。',
+        '王天华开心得一直说话。',
+        '你说:“怎么办?”我怎么知道?',
+    ]
+    outputs = m.predict(inputs)
+    for a, b in zip(inputs, outputs):
+        print('input :', a)
+        print('predict:', b)
+        print()
+
+    # evaluate the model on the SIGHAN2015 test set
+    # macbert4csc        Sentence Level: acc:0.7845, precision:0.8174, recall:0.7256, f1:0.7688, cost time:10.79 s
+    # softmaskedbert4csc Sentence Level: acc:0.6964, precision:0.8065, recall:0.5064, f1:0.6222, cost time:16.20 s
+    from pycorrector.utils.eval import eval_sighan2015_by_model
+
+    eval_sighan2015_by_model(m.predict_with_error_detail)
diff --git a/lr_scheduler.py b/lr_scheduler.py
new file mode 100644
index 0000000..fd203da
--- /dev/null
+++ b/lr_scheduler.py
@@ -0,0 +1,178 @@
+"""
+@Time   : 2021-01-21 10:52:47
+@File   : lr_scheduler.py
+@Author : Abtion
+@Email  : abtion{at}outlook.com
+"""
+import math
+import warnings
+from bisect import bisect_right
+from typing import List
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+__all__ = ["WarmupMultiStepLR", "WarmupExponentialLR", "WarmupCosineAnnealingLR"]
+
+
+class WarmupMultiStepLR(_LRScheduler):
+    def __init__(
+            self,
+            optimizer: torch.optim.Optimizer,
+            milestones: List[int],
+            gamma: float = 0.1,
+            warmup_factor: float = 0.001,
+            warmup_epochs: int = 2,
+            warmup_method: str = "linear",
+            last_epoch: int = -1,
+            **kwargs,
+    ):
+        if not list(milestones) == sorted(milestones):
+            raise ValueError(
+                "Milestones should be a list of increasing integers. Got {}".format(milestones)
+            )
+        self.milestones = milestones
+        self.gamma = gamma
+        self.warmup_factor = warmup_factor
+        self.warmup_epochs = warmup_epochs
+        self.warmup_method = warmup_method
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> List[float]:
+        warmup_factor = _get_warmup_factor_at_iter(
+            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
+        )
+        return [
+            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
+            for base_lr in self.base_lrs
+        ]
+
+    def _compute_values(self) -> List[float]:
+        # The new interface
+        return self.get_lr()
+
+
+class WarmupExponentialLR(_LRScheduler):
+    """Decays the learning rate of each parameter group by gamma every epoch,
+    after a linear warmup phase. When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        gamma (float): Multiplicative factor of learning rate decay.
+        last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+ """ + + def __init__(self, optimizer, gamma, last_epoch=-1, warmup_epochs=2, warmup_factor=1.0 / 3, verbose=False, + **kwargs): + self.gamma = gamma + self.warmup_method = 'linear' + self.warmup_epochs = warmup_epochs + self.warmup_factor = warmup_factor + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor + ) + + if self.last_epoch <= self.warmup_epochs: + return [base_lr * warmup_factor + for base_lr in self.base_lrs] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma ** self.last_epoch + for base_lr in self.base_lrs] + + +class WarmupCosineAnnealingLR(_LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + + \cos(\frac{T_{cur}}{T_{max}}\pi)) + When last_epoch=-1, sets initial lr as lr. + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + max_iters: int, + delay_iters: int = 0, + eta_min_lr: int = 0, + warmup_factor: float = 0.001, + warmup_epochs: int = 2, + warmup_method: str = "linear", + last_epoch=-1, + **kwargs + ): + self.max_iters = max_iters + self.delay_iters = delay_iters + self.eta_min_lr = eta_min_lr + self.warmup_factor = warmup_factor + self.warmup_epochs = warmup_epochs + self.warmup_method = warmup_method + assert self.delay_iters >= self.warmup_epochs, "Scheduler delay iters must be larger than warmup iters" + super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + if self.last_epoch <= self.warmup_epochs: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor, + ) + return [ + base_lr * warmup_factor for base_lr in self.base_lrs + ] + elif self.last_epoch <= self.delay_iters: + return self.base_lrs + + else: + return [ + self.eta_min_lr + (base_lr - self.eta_min_lr) * + (1 + math.cos( + math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2 + for base_lr in self.base_lrs] + + +def _get_warmup_factor_at_iter( + method: str, iter: int, warmup_iters: int, warmup_factor: float +) -> float: + """ + Return the learning rate warmup factor at a specific iteration. + See https://arxiv.org/abs/1706.02677 for more details. + Args: + method (str): warmup method; either "constant" or "linear". + iter (int): iteration at which to calculate the warmup factor. + warmup_iters (int): the number of warmup iterations. 
+ warmup_factor (float): the base warmup factor (the meaning changes according + to the method used). + Returns: + float: the effective warmup factor at the given iteration. + """ + if iter >= warmup_iters: + return 1.0 + + if method == "constant": + return warmup_factor + elif method == "linear": + alpha = iter / warmup_iters + return warmup_factor * (1 - alpha) + alpha + else: + raise ValueError("Unknown warmup method: {}".format(method)) diff --git a/macbert4csc.py b/macbert4csc.py new file mode 100644 index 0000000..60b34b1 --- /dev/null +++ b/macbert4csc.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com) +@description: +""" +from abc import ABC + +import torch.nn as nn +from transformers import BertForMaskedLM +from pycorrector.macbert.base_model import CscTrainingModel, FocalLoss + + +class MacBert4Csc(CscTrainingModel, ABC): + def __init__(self, cfg, tokenizer): + super().__init__(cfg) + self.cfg = cfg + self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT) + self.detection = nn.Linear(self.bert.config.hidden_size, 1) + self.sigmoid = nn.Sigmoid() + self.tokenizer = tokenizer + + def forward(self, texts, cor_labels=None, det_labels=None): + if cor_labels: + text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids'] + text_labels[text_labels == 0] = -100 # -100计算损失时会忽略 + text_labels = text_labels.to(self.device) + else: + text_labels = None + encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt') + encoded_text.to(self.device) + bert_outputs = self.bert(**encoded_text, labels=text_labels, return_dict=True, output_hidden_states=True) + # 检错概率 + prob = self.detection(bert_outputs.hidden_states[-1]) + + if text_labels is None: + # 检错输出,纠错输出 + outputs = (prob, bert_outputs.logits) + else: + det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid') + # pad部分不计算损失 + active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1 + active_probs = prob.view(-1, prob.shape[1])[active_loss] + active_labels = det_labels[active_loss] + det_loss = det_loss_fct(active_probs, active_labels.float()) + # 检错loss,纠错loss,检错输出,纠错输出 + outputs = (det_loss, + bert_outputs.loss, + self.sigmoid(prob).squeeze(-1), + bert_outputs.logits) + return outputs diff --git a/macbert_corrector.py b/macbert_corrector.py new file mode 100644 index 0000000..4bf5c4d --- /dev/null +++ b/macbert_corrector.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +""" +@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com) +@description: +""" +import operator +import sys +import time +import os +from transformers import BertTokenizerFast, BertForMaskedLM +import torch +from typing import List +from loguru import logger + +sys.path.append('../..') +from pycorrector import config +from pycorrector.utils.tokenizer import split_text_by_maxlen + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +unk_tokens = [' ', '“', '”', '‘', '’', '\n', '…', '—', '擤', '\t', '֍', '玕', ''] + + +def get_errors(corrected_text, origin_text): + sub_details = [] + for i, ori_char in enumerate(origin_text): + if i >= len(corrected_text): + break + if ori_char in unk_tokens: + # deal with unk word + corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] + continue + if ori_char != corrected_text[i]: + if ori_char.lower() == corrected_text[i]: + # pass english upper char + corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] + 
continue + sub_details.append((ori_char, corrected_text[i], i, i + 1)) + sub_details = sorted(sub_details, key=operator.itemgetter(2)) + return corrected_text, sub_details + + +class MacBertCorrector(object): + def __init__(self, model_dir=config.macbert_model_dir): + self.name = 'macbert_corrector' + t1 = time.time() + bin_path = os.path.join(model_dir, 'pytorch_model.bin') + if not os.path.exists(bin_path): + model_dir = "shibing624/macbert4csc-base-chinese" + logger.warning(f'local model {bin_path} not exists, use default HF model {model_dir}') + + self.tokenizer = BertTokenizerFast.from_pretrained(model_dir) + self.model = BertForMaskedLM.from_pretrained(model_dir) + self.model.to(device) + logger.debug("Use device: {}".format(device)) + logger.debug('Loaded macbert4csc model: %s, spend: %.3f s.' % (model_dir, time.time() - t1)) + + def macbert_correct(self, text, threshold=0.7, verbose=False): + """ + 句子纠错 + :param text: 句子文本 + :param threshold: 阈值 + :param verbose: 是否打印详细信息 + :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos] + """ + text_new = '' + details = [] + # 长句切分为短句 + blocks = split_text_by_maxlen(text, maxlen=128) + block_texts = [block[0] for block in blocks] + inputs = self.tokenizer(block_texts, padding=True, return_tensors='pt').to(device) + with torch.no_grad(): + outputs = self.model(**inputs) + + for ids, (text, idx) in zip(outputs.logits, blocks): + decode_tokens_new = self.tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).split(' ') + decode_tokens_old = self.tokenizer.decode(inputs['input_ids'][idx], skip_special_tokens=True).split(' ') + if len(decode_tokens_new) != len(decode_tokens_old): + continue + probs = torch.max(torch.softmax(ids, dim=-1), dim=-1)[0].cpu().numpy() + decode_tokens = '' + for i in range(len(decode_tokens_old)): + if probs[i + 1] >= threshold: + if verbose: + logger.debug( + f"word: {decode_tokens_old[i]}, prob: {probs[i + 1]}, new word: {decode_tokens_new[i]}") + decode_tokens += decode_tokens_new[i] + else: + decode_tokens += decode_tokens_old[i] + corrected_text = decode_tokens[:len(text)] + corrected_text, sub_details = get_errors(corrected_text, text) + text_new += corrected_text + sub_details = [(i[0], i[1], idx + i[2], idx + i[3]) for i in sub_details] + details.extend(sub_details) + return text_new, details + + def batch_macbert_correct(self, texts: List[str], max_length: int = 128): + """ + 句子纠错 + :param texts: list[str], sentence list + :param max_length: int, max length of each sentence + :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos] + """ + result = [] + + inputs = self.tokenizer(texts, padding=True, return_tensors='pt').to(device) + with torch.no_grad(): + outputs = self.model(**inputs) + for ids, (i, text) in zip(outputs.logits, enumerate(texts)): + text_new = '' + details = [] + corrected_text = self.tokenizer.decode((torch.argmax(ids, dim=-1) * inputs.attention_mask[i]), + skip_special_tokens=True).replace(' ', '') + corrected_text, sub_details = get_errors(corrected_text, text) + text_new += corrected_text + sub_details = [(i[0], i[1], i[2], i[3]) for i in sub_details] + details.extend(sub_details) + result.append([text_new, details]) + return result + + +if __name__ == "__main__": + m = MacBertCorrector() + error_sentences = [ + '内容提要——在知识产权学科领域里', + '疝気医院那好 为老人让坐,疝気专科百科问答', + '少先队员因该为老人让坐', + '少 先 队 员 因 该 为 老人让坐', + '机七学习是人工智能领遇最能体现智能的一个分知', + '今天心情很好', + '老是较书。', + '遇到一位很棒的奴生跟我聊天。', + '他的语说的很好,法语也不错', + '他法语说的很好,的语也不错', + 
'他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
+        '影像小孩子想的快,学习管理的斑法',
+        '餐厅的换经费产适合约会',
+        '走路真的麻坊,我也没有喝的东西,在家汪了',
+        '因为爸爸在看录音机,所以我没得看',
+        '不过在许多传统国家,女人向未得到平等',
+        '妈妈说:"别趴地上了,快起来,你还吃饭吗?",我说:"好。"就扒起来了。',
+        '你说:“怎么办?”我怎么知道?',
+        '我父母们常常说:“那时候吃的东西太少,每天只能吃一顿饭。”想一想,人们都快要饿死,谁提出化肥和农药的污染。',
+        '这本新书《居里夫人传》将的很生动有趣',
+        '֍我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡……֍新的食谱',
+        '注意:“跨类保护”不等于“全类保护”。',
+        '12.——对比文件中未公开的数值和对比文件中已经公开的中间值具有新颖性;',
+        '《著作权法》(2020修正)第23条:“自然人的作品,其发表权、本法第',
+        '三步检验法(三步检验标准)(three-step test):若要',
+        '①申请人应提交,顺应了国家“健全创新激励',
+        '①申请人应提交,太平天国领导人洪仁玕。',
+        ' 部分优先权:',
+        '实施其专利的行为(生产经营≠营利≠商业经营)',
+        '实施,i can speak chinese, can i spea english. ? hello.',
+        "我不唉“看 琅擤琊榜”",
+    ]
+    t1 = time.time()
+    for sent in error_sentences:
+        corrected_sent, err = m.macbert_correct(sent, 0.6)
+        print("original sentence:{} => {} err:{}".format(sent, corrected_sent, err))
+    print('[single]spend time:', time.time() - t1)
+    t2 = time.time()
+    res = m.batch_macbert_correct(error_sentences)
+    for sent, r in zip(error_sentences, res):
+        print("original sentence:{} => {} err:{}".format(sent, r[0], r[1]))
+    print('[batch]spend time:', time.time() - t2)
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..d098052
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,46 @@
+import operator
+import torch
+from transformers import BertTokenizerFast, BertForMaskedLM
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
+model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
+model.to(device)
+
+texts = ["今天新情很好,你找到你最喜欢的工作,我也很高心。", "今天新情很好,你找到你最喜欢的工作,我也很高心。"]
+with torch.no_grad():
+    input = tokenizer(texts, padding=True, return_tensors='pt').to(device)
+    print(input)
+    input_ids = input['input_ids'].to(device)
+    token_type_ids = input["token_type_ids"].to(device)
+    attention_mask = input['attention_mask'].to(device)
+    print()
+    # pass keyword arguments: the positional order of BertForMaskedLM.forward is
+    # (input_ids, attention_mask, token_type_ids), so passing token_type_ids
+    # second would silently feed it in as the attention mask
+    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+def get_errors(corrected_text, origin_text):
+    sub_details = []
+    for i, ori_char in enumerate(origin_text):
+        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
+            # add unk word
+            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
+            continue
+        if i >= len(corrected_text):
+            continue
+        if ori_char != corrected_text[i]:
+            if ori_char.lower() == corrected_text[i]:
+                # pass english upper char
+                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
+                continue
+            sub_details.append((ori_char, corrected_text[i], i, i + 1))
+    sub_details = sorted(sub_details, key=operator.itemgetter(2))
+    return corrected_text, sub_details
+
+
+result = []
+for ids, text in zip(outputs.logits, texts):
+    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
+    corrected_text = _text[:len(text)]
+    corrected_text, details = get_errors(corrected_text, text)
+    print(text, ' => ', corrected_text, details)
+    result.append((text, corrected_text, details))
+print(result)
\ No newline at end of file
diff --git a/preprocess.py b/preprocess.py
new file mode 100644
index 0000000..46a9e89
--- /dev/null
+++ b/preprocess.py
@@ -0,0 +1,200 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
+@description:
+"""
+import os
+import sys
+from codecs import open
+from sklearn.model_selection import train_test_split
+from lxml import etree
+from xml.dom import minidom
+
+sys.path.append('../..')
+from pycorrector.utils.io_utils import save_json
+from pycorrector import 
traditional2simplified + +pwd_path = os.path.abspath(os.path.dirname(__file__)) + + +def read_data(fp): + for fn in os.listdir(fp): + if fn.endswith('ing.sgml'): + with open(os.path.join(fp, fn), 'r', encoding='utf-8', errors='ignore') as f: + item = [] + for line in f: + if line.strip().startswith(' 0: + yield ''.join(item) + item = [line.strip()] + elif line.strip().startswith('<'): + item.append(line.strip()) + + +def proc_item(item): + """ + 处理训练数据集 + Args: + item: + Returns: + list + """ + root = etree.XML(item) + passages = dict() + mistakes = [] + for passage in root.xpath('/ESSAY/TEXT/PASSAGE'): + passages[passage.get('id')] = traditional2simplified(passage.text) + for mistake in root.xpath('/ESSAY/MISTAKE'): + mistakes.append({'id': mistake.get('id'), + 'location': int(mistake.get('location')) - 1, + 'wrong': traditional2simplified(mistake.xpath('./WRONG/text()')[0].strip()), + 'correction': traditional2simplified(mistake.xpath('./CORRECTION/text()')[0].strip())}) + + rst_items = dict() + + def get_passages_by_id(pgs, _id): + p = pgs.get(_id) + if p: + return p + _id = _id[:-1] + str(int(_id[-1]) + 1) + p = pgs.get(_id) + if p: + return p + raise ValueError(f'passage not found by {_id}') + + for mistake in mistakes: + if mistake['id'] not in rst_items.keys(): + rst_items[mistake['id']] = {'original_text': get_passages_by_id(passages, mistake['id']), + 'wrong_ids': [], + 'correct_text': get_passages_by_id(passages, mistake['id'])} + + ori_text = rst_items[mistake['id']]['original_text'] + cor_text = rst_items[mistake['id']]['correct_text'] + if len(ori_text) == len(cor_text): + if ori_text[mistake['location']] in mistake['wrong']: + rst_items[mistake['id']]['wrong_ids'].append(mistake['location']) + wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']]) + start = mistake['location'] - wrong_char_idx + end = start + len(mistake['wrong']) + rst_items[mistake['id']][ + 'correct_text'] = f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}' + else: + print(f'error line:\n{mistake["id"]}\n{ori_text}\n{cor_text}') + rst = [] + for k in rst_items.keys(): + if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']): + rst.append({'id': k, **rst_items[k]}) + else: + text = rst_items[k]['correct_text'] + rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []}) + return rst + + +def proc_test_set(fp): + """ + 生成sighan15的测试集 + Args: + fp: + Returns: + """ + inputs = dict() + with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf-8') as f: + for line in f: + pid = line[5:14] + text = line[16:].strip() + inputs[pid] = text + + rst = [] + with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf-8') as f: + for line in f: + pid = line[0:9] + mistakes = line[11:].strip().split(', ') + if len(mistakes) <= 1: + text = traditional2simplified(inputs[pid]) + rst.append({'id': pid, + 'original_text': text, + 'wrong_ids': [], + 'correct_text': text}) + else: + wrong_ids = [] + original_text = inputs[pid] + cor_text = inputs[pid] + for i in range(len(mistakes) // 2): + idx = int(mistakes[2 * i]) - 1 + cor_char = mistakes[2 * i + 1] + wrong_ids.append(idx) + cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}' + original_text = traditional2simplified(original_text) + cor_text = traditional2simplified(cor_text) + if len(original_text) != len(cor_text): + print('error line:\n', pid) + print(original_text) + print(cor_text) + continue + rst.append({'id': pid, + 'original_text': original_text, + 
+                            'wrong_ids': wrong_ids,
+                            'correct_text': cor_text})
+
+    return rst
+
+
+def parse_cged_file(file_dir):
+    rst = []
+    for fn in os.listdir(file_dir):
+        if fn.endswith('.xml'):
+            path = os.path.join(file_dir, fn)
+            print('Parse data from %s' % path)
+
+            dom_tree = minidom.parse(path)
+            docs = dom_tree.documentElement.getElementsByTagName('DOC')
+            for doc in docs:
+                id = ''
+                text = ''
+                texts = doc.getElementsByTagName('TEXT')
+                for i in texts:
+                    id = i.getAttribute('id')
+                    # Input the text
+                    text = i.childNodes[0].data.strip()
+                # Input the correct text
+                correction = doc.getElementsByTagName('CORRECTION')[0]. \
+                    childNodes[0].data.strip()
+                wrong_ids = []
+                for error in doc.getElementsByTagName('ERROR'):
+                    start_off = error.getAttribute('start_off')
+                    end_off = error.getAttribute('end_off')
+                    if start_off and end_off:
+                        for i in range(int(start_off), int(end_off) + 1):
+                            wrong_ids.append(i)
+                source = text.strip()
+                target = correction.strip()
+
+                pair = [source, target]
+                if pair not in [[r['original_text'], r['correct_text']] for r in rst]:
+                    rst.append({'id': id,
+                                'original_text': source,
+                                'wrong_ids': wrong_ids,
+                                'correct_text': target
+                                })
+    save_json(rst, os.path.join(pwd_path, 'output/cged.json'))
+    return rst
+
+
+def main():
+    # 注意:该训练样本较少,仅作为模型测试使用
+    # parse_cged_file(os.path.join(pwd_path, '../data/cn/CGED/'))
+    sighan15_dir = os.path.join(pwd_path, '../data/cn/sighan_2015/')
+    rst_items = []
+    test_lst = proc_test_set(sighan15_dir)
+    for item in read_data(sighan15_dir):
+        rst_items += proc_item(item)
+
+    # 拆分训练与测试
+    print('data_size:', len(rst_items))
+    train_lst, dev_lst = train_test_split(rst_items, test_size=0.1, random_state=42)
+    save_json(train_lst, os.path.join(pwd_path, 'output/train.json'))
+    save_json(dev_lst, os.path.join(pwd_path, 'output/dev.json'))
+    save_json(test_lst, os.path.join(pwd_path, 'output/test.json'))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/reader.py b/reader.py
new file mode 100644
index 0000000..5e47d8d
--- /dev/null
+++ b/reader.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
+@description:
+"""
+import os
+import json
+import torch
+from torch.utils.data import Dataset
+from transformers import BertTokenizerFast
+from torch.utils.data import DataLoader
+
+
+class DataCollator:
+    def __init__(self, tokenizer: BertTokenizerFast):
+        self.tokenizer = tokenizer
+
+    def __call__(self, data):
+        ori_texts, cor_texts, wrong_idss = zip(*data)
+        encoded_texts = [self.tokenizer(t, return_offsets_mapping=True, add_special_tokens=False) for t in ori_texts]
+        max_len = max([len(t['input_ids']) for t in encoded_texts]) + 2
+        det_labels = torch.zeros(len(ori_texts), max_len).long()
+
+        for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
+            off_mapping = encoded_text['offset_mapping']
+            for idx in wrong_ids:
+                for j, (b, e) in enumerate(off_mapping):
+                    if b <= idx < e:
+                        # j+1是因为前面的 CLS token
+                        det_labels[i, j + 1] = 1
+                        break
+
+        return list(ori_texts), list(cor_texts), det_labels
+
+
+class CscDataset(Dataset):
+    def __init__(self, file_path):
+        self.data = json.load(open(file_path, 'r', encoding='utf-8'))
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']
+
+
+def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
+                 batch_size=32, num_workers=4):
+    train_loader = None
+    if train_path and os.path.exists(train_path):
+        train_loader = DataLoader(CscDataset(train_path),
+                                  batch_size=batch_size,
+                                  shuffle=False,
+                                  num_workers=num_workers,
+                                  collate_fn=collate_fn)
+    valid_loader = None
+    if valid_path and os.path.exists(valid_path):
+        valid_loader = DataLoader(CscDataset(valid_path),
+                                  batch_size=batch_size,
+                                  num_workers=num_workers,
+                                  collate_fn=collate_fn)
+    test_loader = None
+    if test_path and os.path.exists(test_path):
+        test_loader = DataLoader(CscDataset(test_path),
+                                 batch_size=batch_size,
+                                 num_workers=num_workers,
+                                 collate_fn=collate_fn)
+    return train_loader, valid_loader, test_loader
diff --git a/rewrite.log b/rewrite.log
new file mode 100644
index 0000000..83d0525
--- /dev/null
+++ b/rewrite.log
@@ -0,0 +1,131 @@
+2023-05-12 11:44:02,567 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-05-12 11:44:02,567 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
+2023-05-12 11:47:41,934 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-05-12 11:47:41,935 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
+2023-05-12 11:47:43,032 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:47:43] "POST /predict HTTP/1.1" 200 -
+2023-05-12 11:48:12,651 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:48:12] "POST /predict HTTP/1.1" 200 -
+2023-05-12 11:48:32,073 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:48:32] "POST /predict HTTP/1.1" 200 -
+2023-05-12 11:49:22,664 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-05-12 11:49:22,664 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
+2023-05-12 11:49:25,857 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:49:25] "POST /predict HTTP/1.1" 200 -
+2023-06-13 19:24:21,082 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-13 19:24:21,083 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-13 19:25:24,092 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:25:24] "POST /predict HTTP/1.1" 200 -
+2023-06-13 19:26:11,123 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:26:11] "POST /predict HTTP/1.1" 200 -
+2023-06-13 19:35:27,653 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:35:27] "POST /predict HTTP/1.1" 200 -
+2023-06-13 19:37:38,520 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-13 19:37:38,520 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-13 19:44:29,271 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
+Traceback (most recent call last):
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 248, in __getattr__
+    return self.data[item]
+KeyError: 'size'
+
+During handling of the above exception, another exception occurred:
+
+Traceback (most recent call last):
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
+    response = self.full_dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
+    rv = self.handle_user_exception(e)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
+    reraise(exc_type, exc_value, tb)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
+    raise value
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
+    rv = self.dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
+    return self.view_functions[rule.endpoint](**req.view_args)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
+    return_list = main(texts)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
+    outputs = model(a)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
+    result = forward_call(*input, **kwargs)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1345, in forward
+    return_dict=return_dict,
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
+    result = forward_call(*input, **kwargs)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 944, in forward
+    input_shape = input_ids.size()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 250, in __getattr__
+    raise AttributeError
+AttributeError
+2023-06-13 19:44:29,291 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:44:29] "POST /predict HTTP/1.1" 500 -
+2023-06-13 19:45:26,160 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-13 19:45:26,161 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-13 19:45:28,616 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
+Traceback (most recent call last):
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 248, in __getattr__
+    return self.data[item]
+KeyError: 'size'
+
+During handling of the above exception, another exception occurred:
+
+Traceback (most recent call last):
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
+    response = self.full_dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
+    rv = self.handle_user_exception(e)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
+    reraise(exc_type, exc_value, tb)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
+    raise value
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
+    rv = self.dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
+    return self.view_functions[rule.endpoint](**req.view_args)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
+    return_list = main(texts)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
+    outputs = model(ids)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
+    result = forward_call(*input, **kwargs)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1345, in forward
+    return_dict=return_dict,
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
+    result = forward_call(*input, **kwargs)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 944, in forward
+    input_shape = input_ids.size()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 250, in __getattr__
+    raise AttributeError
+AttributeError
+2023-06-13 19:45:28,623 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:45:28] "POST /predict HTTP/1.1" 500 -
+2023-06-13 19:45:58,576 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-13 19:45:58,577 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-13 19:50:29,338 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-13 19:50:29,339 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-13 19:50:32,134 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
+Traceback (most recent call last):
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
+    response = self.full_dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
+    rv = self.handle_user_exception(e)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
+    reraise(exc_type, exc_value, tb)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
+    raise value
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
+    rv = self.dispatch_request()
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
+    return self.view_functions[rule.endpoint](**req.view_args)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
+    return_list = main(texts)
+  File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
+    outputs = model.run(None, ids_input)
+  File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1178, in __getattr__
+    type(self).__name__, name))
+AttributeError: 'BertForMaskedLM' object has no attribute 'run'
+2023-06-13 19:50:32,146 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:50:32] "POST /predict HTTP/1.1" 500 -
+2023-06-14 14:07:51,042 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
+   WARNING: This is a development server. Do not use it in a production deployment.
+2023-06-14 14:07:51,047 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
+2023-06-14 14:07:54,329 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [14/Jun/2023 14:07:54] "POST /predict HTTP/1.1" 200 -
diff --git a/softmaskedbert4csc.py b/softmaskedbert4csc.py
new file mode 100644
index 0000000..46d4e6b
--- /dev/null
+++ b/softmaskedbert4csc.py
@@ -0,0 +1,151 @@
+"""
+@Time   : 2021-01-21 12:00:59
+@File   : modeling_soft_masked_bert.py
+@Author : Abtion
+@Email  : abtion{at}outlook.com
+"""
+from abc import ABC
+from collections import OrderedDict
+import transformers as tfs
+import torch
+from torch import nn
+from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead
+from transformers.modeling_utils import ModuleUtilsMixin
+from pycorrector.macbert.base_model import CscTrainingModel
+
+
+class DetectionNetwork(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.gru = nn.GRU(
+            self.config.hidden_size,
+            self.config.hidden_size // 2,
+            num_layers=2,
+            batch_first=True,
+            dropout=self.config.hidden_dropout_prob,
+            bidirectional=True,
+        )
+        self.sigmoid = nn.Sigmoid()
+        self.linear = nn.Linear(self.config.hidden_size, 1)
+
+    def forward(self, hidden_states):
+        out, _ = self.gru(hidden_states)
+        prob = self.linear(out)
+        prob = self.sigmoid(prob)
+        return prob
+
+
+class CorrectionNetwork(torch.nn.Module, ModuleUtilsMixin):
+    def __init__(self, config, tokenizer, device):
+        super().__init__()
+        self.config = config
+        self.tokenizer = tokenizer
+        self.embeddings = BertEmbeddings(self.config)
+        self.bert = BertEncoder(self.config)
+        self.mask_token_id = self.tokenizer.mask_token_id
+        self.cls = BertOnlyMLMHead(self.config)
+        self._device = device
+
+    def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
+        if cor_labels is not None:
+            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
+            # torch的cross entropy loss 会忽略-100的label
+            text_labels[text_labels == 0] = -100
+            text_labels = text_labels.to(self._device)
+        else:
+            text_labels = None
+        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
+        encoded_texts.to(self._device)
+        if embed is None:
+            embed = self.embeddings(input_ids=encoded_texts['input_ids'],
+                                    token_type_ids=encoded_texts['token_type_ids'])
+        # 此处较原文有一定改动,做此改动意在完整保留type_ids及position_ids的embedding。
+        mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
+        # 此处为原文实现
+        # mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
+        cor_embed = prob * mask_embed + (1 - prob) * embed
+
+        input_shape = encoded_texts['input_ids'].size()
+        device = encoded_texts['input_ids'].device
+
+        extended_attention_mask = self.get_extended_attention_mask(encoded_texts['attention_mask'],
+                                                                   input_shape, device)
+        head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
+        encoder_outputs = self.bert(
+            cor_embed,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            return_dict=False,
+        )
+        sequence_output = encoder_outputs[0]
+
+        sequence_output = sequence_output + embed if residual_connection else sequence_output
+        prediction_scores = self.cls(sequence_output)
+        out = (prediction_scores, sequence_output)
+
+        # Masked language modeling softmax layer
+        if text_labels is not None:
+            loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
+            cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
+            out = (cor_loss,) + out
+        return out
+
+    def load_from_transformers_state_dict(self, gen_fp):
+        state_dict = OrderedDict()
+        gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
+        for k, v in gen_state_dict.items():
+            name = k
+            if name.startswith('bert'):
+                name = name[5:]
+            if name.startswith('encoder'):
+                name = f'corrector.{name[8:]}'
+            if 'gamma' in name:
+                name = name.replace('gamma', 'weight')
+            if 'beta' in name:
+                name = name.replace('beta', 'bias')
+            state_dict[name] = v
+        self.load_state_dict(state_dict, strict=False)
+
+
+class SoftMaskedBert4Csc(CscTrainingModel, ABC):
+    def __init__(self, cfg, tokenizer):
+        super().__init__(cfg)
+        self.cfg = cfg
+        self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
+        self.detector = DetectionNetwork(self.config)
+        self.tokenizer = tokenizer
+        self.corrector = CorrectionNetwork(self.config, tokenizer, cfg.MODEL.DEVICE)
+        self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
+        self._device = cfg.MODEL.DEVICE
+
+    def forward(self, texts, cor_labels=None, det_labels=None):
+        encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
+        encoded_texts.to(self._device)
+        embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
+                                          token_type_ids=encoded_texts['token_type_ids'])
+        prob = self.detector(embed)
+        cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)
+
+        if det_labels is not None:
+            det_loss_fct = nn.BCELoss()
+            # pad部分不计算损失
+            active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
+            active_probs = prob.view(-1, prob.shape[1])[active_loss]
+            active_labels = det_labels[active_loss]
+            det_loss = det_loss_fct(active_probs, active_labels.float())
+            outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
+        else:
+            outputs = (prob.squeeze(-1),) + cor_out
+
+        return outputs
+
+    def load_from_transformers_state_dict(self, gen_fp):
+        """
+        从transformers加载预训练权重
+        :param gen_fp:
+        :return:
+        """
+        self.corrector.load_from_transformers_state_dict(gen_fp)
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..c8648be
--- /dev/null
+++ b/train.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
+@description:
+"""
+import os
+import sys
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
+from transformers import BertTokenizerFast, BertForMaskedLM
+import argparse
+from collections import OrderedDict
+from loguru import logger
+
+sys.path.append('../..')
+
+from pycorrector.macbert.reader import make_loaders, DataCollator
+from pycorrector.macbert.macbert4csc import MacBert4Csc
+from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
+from pycorrector.macbert import preprocess
+from pycorrector.macbert.defaults import _C as cfg
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
+
+
+def args_parse(config_file=''):
+    parser = argparse.ArgumentParser(description="csc")
+    parser.add_argument(
+        "--config_file", default="train_macbert4csc.yml", help="path to config file", type=str
+    )
+
+    parser.add_argument("--opts", help="Modify config options using the command-line key value", default=[],
+                        nargs=argparse.REMAINDER)
+
+    args = parser.parse_args()
+
+    config_file = args.config_file or config_file
+    cfg.merge_from_file(config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+
+    logger.info(args)
+
+    if config_file != '':
+        logger.info("Loaded configuration file {}".format(config_file))
+        with open(config_file, 'r') as cf:
+            config_str = "\n" + cf.read()
+            logger.info(config_str)
+
+    logger.info("Running with config:\n{}".format(cfg))
+    return cfg
+
+
+def main():
+    cfg = args_parse()
+
+    # 如果不存在训练文件则先处理数据
+    if not os.path.exists(cfg.DATASETS.TRAIN):
+        logger.debug('preprocess data')
+        preprocess.main()
+    logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
+    tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
+    collator = DataCollator(tokenizer=tokenizer)
+    # 加载数据
+    train_loader, valid_loader, test_loader = make_loaders(collator, train_path=cfg.DATASETS.TRAIN,
+                                                           valid_path=cfg.DATASETS.VALID, test_path=cfg.DATASETS.TEST,
+                                                           batch_size=cfg.SOLVER.BATCH_SIZE, num_workers=4)
+    if cfg.MODEL.NAME == 'softmaskedbert4csc':
+        model = SoftMaskedBert4Csc(cfg, tokenizer)
+    elif cfg.MODEL.NAME == 'macbert4csc':
+        model = MacBert4Csc(cfg, tokenizer)
+    else:
+        raise ValueError("model not found.")
+    # 加载之前保存的模型,继续训练(load_from_checkpoint 是 classmethod,需接收返回的新实例)
+    if cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
+        model = model.load_from_checkpoint(checkpoint_path=cfg.MODEL.WEIGHTS, cfg=cfg, map_location=device, tokenizer=tokenizer)
+    # 配置模型保存参数
+    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
+    ckpt_callback = ModelCheckpoint(
+        monitor='val_loss',
+        dirpath=cfg.OUTPUT_DIR,
+        filename='{epoch:02d}-{val_loss:.2f}',
+        save_top_k=1,
+        mode='min'
+    )
+    # 训练模型
+    logger.info('train model ...')
+    trainer = pl.Trainer(max_epochs=cfg.SOLVER.MAX_EPOCHS,
+                         gpus=None if device == torch.device('cpu') else cfg.MODEL.GPU_IDS,
+                         accumulate_grad_batches=cfg.SOLVER.ACCUMULATE_GRAD_BATCHES,
+                         callbacks=[ckpt_callback])
+    # 进行训练
+    # train_loader中有数据
+    torch.autograd.set_detect_anomaly(True)
+    if 'train' in cfg.MODE and train_loader and len(train_loader) > 0:
+        if valid_loader and len(valid_loader) > 0:
+            trainer.fit(model, train_loader, valid_loader)
+        else:
+            trainer.fit(model, train_loader)
+        logger.info('train model done.')
+    # 模型转为transformers可加载
+    if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
+        ckpt_path = ckpt_callback.best_model_path
+    elif cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
+        ckpt_path = cfg.MODEL.WEIGHTS
+    else:
+        ckpt_path = ''
+    logger.info(f'ckpt_path: {ckpt_path}')
+    if ckpt_path and os.path.exists(ckpt_path):
+        model.load_state_dict(torch.load(ckpt_path)['state_dict'])
+        # 先保存原始transformer bert model
+        tokenizer.save_pretrained(cfg.OUTPUT_DIR)
+        bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
+        bert.save_pretrained(cfg.OUTPUT_DIR)
+        state_dict = torch.load(ckpt_path)['state_dict']
+        new_state_dict = OrderedDict()
+        if cfg.MODEL.NAME in ['macbert4csc']:
+            for k, v in state_dict.items():
+                if k.startswith('bert.'):
+                    new_state_dict[k[5:]] = v
+        else:
+            new_state_dict = state_dict
+        # 再保存finetune训练后的模型文件,替换原始的pytorch_model.bin
+        torch.save(new_state_dict, os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin'))
+    # 进行测试的逻辑同训练
+    if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
+        trainer.test(model, test_loader)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/train_macbert4csc.yml b/train_macbert4csc.yml
new file mode 100644
index 0000000..e7fc368
--- /dev/null
+++ b/train_macbert4csc.yml
@@ -0,0 +1,24 @@
+MODEL:
+  BERT_CKPT: "hfl/chinese-macbert-base"
+  DEVICE: "cuda"
+  NAME: "macbert4csc"
+  GPU_IDS: [0]
+  # [loss_coefficient]
+  HYPER_PARAMS: [0.3]
+  #WEIGHTS: "output/macbert4csc/epoch=6-val_loss=0.07.ckpt"
+  WEIGHTS: ""
+
+DATASETS:
+  TRAIN: "output/train.json"
+  VALID: "output/dev.json"
+  TEST: "output/test.json"
+
+SOLVER:
+  BASE_LR: 5e-5
+  WEIGHT_DECAY: 0.01
+  BATCH_SIZE: 32
+  MAX_EPOCHS: 10
+  ACCUMULATE_GRAD_BATCHES: 4
+
+OUTPUT_DIR: "output/macbert4csc"
+MODE: ["train", "test"]
diff --git a/train_softmaskedbert4csc.yml b/train_softmaskedbert4csc.yml
new file mode 100644
index 0000000..ac225ed
--- /dev/null
+++ b/train_softmaskedbert4csc.yml
@@ -0,0 +1,24 @@
+MODEL:
+  BERT_CKPT: "bert-base-chinese"
+  DEVICE: "cuda"
+  NAME: "softmaskedbert4csc"
+  GPU_IDS: [0]
+  # [loss_coefficient]
+  HYPER_PARAMS: [0.8]
+  #WEIGHTS: "output/softmaskedbert4csc/epoch=2-val_loss=0.07.ckpt"
+  WEIGHTS: ""
+
+DATASETS:
+  TRAIN: "output/train.json"
+  VALID: "output/dev.json"
+  TEST: "output/test.json"
+
+SOLVER:
+  BASE_LR: 0.0001
+  WEIGHT_DECAY: 5e-8
+  BATCH_SIZE: 32
+  MAX_EPOCHS: 10
+  ACCUMULATE_GRAD_BATCHES: 4
+
+OUTPUT_DIR: "output/softmaskedbert4csc"
+MODE: ["train", "test"]
\ No newline at end of file
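
A brief usage note on the two configs above: `train.py` merges the YAML into the `_C` config from `defaults.py` via `cfg.merge_from_file`, and trailing `KEY VALUE` pairs passed through `--opts` override individual entries (for example `python train.py --config_file train_macbert4csc.yml --opts SOLVER.MAX_EPOCHS 5`). After a completed run, `OUTPUT_DIR` contains a saved tokenizer plus a `pytorch_model.bin` whose keys have been rewritten so that a plain `BertForMaskedLM` can load it. The sketch below shows one way to query that exported model; it is an illustration alongside the patch, not part of it, and it assumes training finished with the default `OUTPUT_DIR: "output/macbert4csc"`.

```python
import torch
from transformers import BertTokenizerFast, BertForMaskedLM

# Assumed path: the OUTPUT_DIR from train_macbert4csc.yml after a completed run.
model_dir = "output/macbert4csc"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForMaskedLM.from_pretrained(model_dir).to(device).eval()

texts = ["今天新情很好"]
with torch.no_grad():
    inputs = tokenizer(texts, padding=True, return_tensors="pt").to(device)
    logits = model(**inputs).logits

for text, ids in zip(texts, logits.argmax(dim=-1)):
    # decode() inserts spaces between CJK tokens; strip them before reading off the result
    decoded = tokenizer.decode(ids, skip_special_tokens=True).replace(" ", "")
    print(text, "=>", decoded[:len(text)])
```

The truncation to `len(text)` mirrors what `predict.py` does: the MLM head emits one token per input position, so the corrected sentence is read off position-by-position rather than generated freely.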