
First commit

master
majiahui@haimaqingfan.com 2 years ago
commit c7f4f8ff4f
  1. .idea/.gitignore (+9)
  2. .idea/deployment.xml (+28)
  3. .idea/inspectionProfiles/profiles_settings.xml (+6)
  4. .idea/macbert.iml (+12)
  5. .idea/misc.xml (+4)
  6. .idea/modules.xml (+8)
  7. .idea/vcs.xml (+7)
  8. README.md (+237)
  9. __init__.py (+0)
  10. base_model.py (+191)
  11. ceshifenli.py (+94)
  12. correct_demo.py (+43)
  13. defaults.py (+114)
  14. evaluate_util.py (+255)
  15. flask_macbert.py (+159)
  16. infer.py (+116)
  17. lr_scheduler.py (+178)
  18. macbert4csc.py (+50)
  19. macbert_corrector.py (+166)
  20. predict.py (+46)
  21. preprocess.py (+200)
  22. reader.py (+68)
  23. rewrite.log (+131)
  24. softmaskedbert4csc.py (+151)
  25. train.py (+133)
  26. train_macbert4csc.yml (+24)
  27. train_softmaskedbert4csc.yml (+24)

.idea/.gitignore (+9)

@@ -0,0 +1,9 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
/macbert4csc-base-chinese

.idea/deployment.xml (+28)

@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
    <serverData>
      <paths name="majiahui@104.244.89.190:27998">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="majiahui@192.168.31.145:22">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="majiahui@192.168.31.145:22 (2)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
  </component>
</project>

.idea/inspectionProfiles/profiles_settings.xml (+6)

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>

.idea/macbert.iml (+12)

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Python 3.6 (bertsum_nezha)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>

.idea/misc.xml (+4)

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (bertsum_nezha)" project-jdk-type="Python SDK" />
</project>

.idea/modules.xml (+8)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/macbert.iml" filepath="$PROJECT_DIR$/.idea/macbert.iml" />
    </modules>
  </component>
</project>

.idea/vcs.xml (+7)

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/macbert4csc-base-chinese" vcs="Git" />
  </component>
</project>

README.md (+237)

@@ -0,0 +1,237 @@
# MacBertMaskedLM For Correction
This project is a Chinese text-correction model that changes the MacBERT network structure; any BERT-style model can serve as the backbone.
> "MacBERT shares the same pre-training tasks as BERT with several modifications." (Cui et al., Findings of EMNLP 2020)

MacBert4csc network architecture:
+ A fully connected layer is added on top of an ordinary BERT model as the error-detection head, i.e. [detection](https://github.com/shibing624/pycorrector/blob/c0f31222b7849c452cc1ec207c71e9954bd6ca08/pycorrector/macbert/macbert4csc.py#L18).
+ Unlike SoftMaskedBERT, the MacBERT model here only combines the detection-layer loss and the correction-layer loss by a weighted sum to obtain the final loss; it does not use the detection layer's confidence as the input weight of the correction layer, as SoftMaskedBERT does.

![macbert_network](https://github.com/shibing624/pycorrector/blob/master/docs/git_image/macbert_network.jpg)
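A minimal sketch of that loss combination, matching `training_step` in base_model.py further down in this commit (the numeric values here are dummies; `w` stands in for `cfg.MODEL.HYPER_PARAMS[0]`):
```python
import torch

# base_model.py combines the two heads as: loss = w * cor_loss + (1 - w) * det_loss
det_loss = torch.tensor(0.8)  # FocalLoss over the detection head (dummy value)
cor_loss = torch.tensor(0.5)  # masked-LM loss over the correction head (dummy value)
w = 0.7                       # stands in for cfg.MODEL.HYPER_PARAMS[0]
loss = w * cor_loss + (1 - w) * det_loss
print(loss)  # tensor(0.5900)
```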
#### About MacBERT
MacBERT stands for "MLM as correction BERT", where MLM means masked language model.
Any BERT-style network can serve as MacBERT's backbone; its defining feature is a different MLM task design during pre-training:
+ whole-word masking (wwm) combined with an N-gram masking strategy is used to select candidate tokens to mask;
+ instead of masking the original word with `[MASK]` as BERT-style models usually do, MacBERT uses a third-party synonym tool to generate a similar word as the mask; when the original word has no similar word, a random n-gram is used instead;
+ as in BERT-style models, for each training sample 80% of the selected tokens are replaced with a similar word (instead of `[MASK]`), 10% with a random word, and 10% are left unchanged.

MLM as Correction mask strategies:
![macbert_strategies](https://github.com/shibing624/pycorrector/blob/master/docs/git_image/macbert_mask_strategies.jpg)
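An illustrative sketch of that replacement rule for one already-selected token (this is not repo code; `get_similar_word` stands in for the third-party synonym tool):
```python
import random

def replace_token(token, vocab, get_similar_word):
    r = random.random()
    if r < 0.8:
        # 80%: mask with a similar word; fall back to a random token when no
        # similar word exists (the paper falls back to a random n-gram)
        return get_similar_word(token) or random.choice(vocab)
    elif r < 0.9:
        return random.choice(vocab)  # 10%: a random word
    return token  # 10%: keep the original word

print(replace_token("好", ["好", "很", "心"], lambda t: None))
```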
## Usage
### Quick load
#### Via pycorrector
example: [correct_demo.py](correct_demo.py)
```python
from pycorrector.macbert.macbert_corrector import MacBertCorrector
nlp = MacBertCorrector("shibing624/macbert4csc-base-chinese").macbert_correct
i = nlp('今天新情很好')
print(i)
```
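`macbert_correct` returns the corrected text together with a list of `(error_word, correct_word, begin_pos, end_pos)` tuples, so the call above prints something like `('今天心情很好', [('新', '心', 2, 3)])`.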
#### Via transformers
You can also call the model directly with the official transformers library.
1. Install transformers with pip:
```shell
pip install transformers>=4.1.1
```
2. Run the following example:
```python
import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("shibing624/macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
model.to(device)

texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
with torch.no_grad():
    outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))


def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # add unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details


result = []
for ids, text in zip(outputs.logits, texts):
    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    corrected_text, details = get_errors(corrected_text, text)
    print(text, ' => ', corrected_text, details)
    result.append((corrected_text, details))
print(result)
```
output:
```shell
今天新情很好 => 今天心情很好 [('新', '心', 2, 3)]
你找到你最喜欢的工作,我也很高心。 => 你找到你最喜欢的工作,我也很高兴。 [('心', '兴', 15, 16)]
```
Model files layout:
```
macbert4csc-base-chinese
├── config.json
├── added_tokens.json
├── pytorch_model.bin
├── special_tokens_map.json
├── tokenizer_config.json
└── vocab.txt
```
## Evaluate
An evaluation script is provided at [pycorrector/utils/eval.py](../utils/eval.py); it has two functions:
- Build the evaluation sample set: the evaluation corpus [pycorrector/data/eval_corpus.json](../data/eval_corpus.json) contains 100 sentences with character-level errors, 100 with word-level errors, 100 with grammar errors, and 200 correct sentences. You can change these counts to generate other sample distributions.
- Compute correction precision and recall: a conservative metric is used; a correction counts as right only if the corrected sentence exactly matches the reference sentence, otherwise it counts as wrong.
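For example, a checkpoint can be scored on SIGHAN2015 with the helper from the same module (a minimal sketch; the `Inference` import path and the file locations are assumptions based on infer.py further down in this commit):
```python
from pycorrector.utils.eval import eval_sighan2015_by_model
from pycorrector.macbert.infer import Inference  # assumed import path

m = Inference(ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
              vocab_path='output/macbert4csc/vocab.txt',
              cfg_path='train_macbert4csc.yml')
eval_sighan2015_by_model(m.predict_with_error_detail)  # prints sentence-level acc/P/R/F1
```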
Running the evaluation script gives the following results.
`shibing624/macbert4csc-base-chinese` evaluated on corpus500:
- Sentence Level: acc:0.6560, precision:0.7797, recall:0.5919, f1:0.6730

The rule-based method (with a custom confusion set) evaluated on corpus500:
- Sentence Level: acc:0.6400, recall:0.5067

`shibing624/macbert4csc-base-chinese` evaluated on the SIGHAN2015 test set:
- Char Level: precision:0.9372, recall:0.8640, f1:0.8991
- Sentence Level: precision:0.8264, recall:0.7366, f1:0.7789

Because the training data includes the SIGHAN2015 training set (to reproduce the paper), the model reaches SOTA level on the SIGHAN2015 test set.
#### Evaluation cases
- run `python tests/macbert_corrector_test.py`
![result](../../docs/git_image/macbert_result.jpg)
## Training
### Install dependencies
```shell
pip install transformers>=4.1.1 pytorch-lightning==1.4.9 torch>=1.7.0 yacs
```
### Training data
#### Toy dataset (~1k samples)
```shell
cd macbert
python preprocess.py
```
This produces the toy dataset files:
```shell
macbert/output
|-- dev.json
|-- test.json
`-- train.json
```
#### SIGHAN+Wang271K Chinese correction dataset
| Dataset | Corpus | Download | Archive size |
| :------- | :--------- | :---------: | :---------: |
| **`SIGHAN+Wang271K Chinese correction dataset`** | SIGHAN+Wang271K (270k samples) | [Baidu Netdisk (code: 01b9)](https://pan.baidu.com/s/1BV5tr9eONZCI0wERFvr0gQ) <br/> [shibing624/CSC](https://huggingface.co/datasets/shibing624/CSC) | 106M |
| **`Original SIGHAN dataset`** | SIGHAN13/14/15 | [official csc.html](http://nlp.ee.ncu.edu.tw/resource/csc.html) | 339K |
| **`Original Wang271K dataset`** | Wang271K | [Automatic-Corpus-Generation, provided by dimmywang](https://github.com/wdimmy/Automatic-Corpus-Generation/blob/master/corpus/train.sgml) | 93M |

Data format of the SIGHAN+Wang271K Chinese correction dataset:
```json
[
  {
    "id": "B2-4029-3",
    "original_text": "晚间会听到嗓音,白天的时候大家都不会太在意,但是在睡觉的时候这嗓音成为大家的恶梦。",
    "wrong_ids": [5, 31],
    "correct_text": "晚间会听到噪音,白天的时候大家都不会太在意,但是在睡觉的时候这噪音成为大家的恶梦。"
  }
]
```
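Since `wrong_ids` are simply the character positions where `original_text` and `correct_text` differ, a record can be built from any aligned sentence pair; a minimal sketch (my own helper, not repo code):
```python
def make_record(rec_id, original_text, correct_text):
    # works only for aligned pairs, i.e. both texts have the same length
    assert len(original_text) == len(correct_text)
    wrong_ids = [i for i, (o, c) in enumerate(zip(original_text, correct_text)) if o != c]
    return {"id": rec_id, "original_text": original_text,
            "wrong_ids": wrong_ids, "correct_text": correct_text}

print(make_record("A1-0001-1", "今天新情很好", "今天心情很好"))
# {'id': 'A1-0001-1', 'original_text': '今天新情很好', 'wrong_ids': [2], 'correct_text': '今天心情很好'}
```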
Download the `SIGHAN+Wang271K Chinese correction dataset`, create an `output` folder and put the files inside, matching the layout above.
#### Your own dataset
Annotate your own data, save it in the same JSON format as the training samples above, then load the model and continue training.
1. If you already have plenty of in-domain error samples, annotate mainly the error positions (wrong_ids) and the corrected sentence (correct_text).
2. If you have no ready-made error samples, you can generate them (original_text) with a script: replace characters at chosen positions (wrong_ids) of correct sentences with wrong ones based on phonetic or visual similarity (a minimal sketch follows this list); see also a
third-party homophone-generation script: [homophone substitution](https://github.com/dongrixinyu/JioNLP/wiki/%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-%E8%AF%B4%E6%98%8E%E6%96%87%E6%A1%A3#%E5%90%8C%E9%9F%B3%E8%AF%8D%E6%9B%BF%E6%8D%A2)
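A minimal sketch of option 2 (the confusion map here is a toy stand-in for a real phonetic/visual confusion set):
```python
import random

CONFUSION = {"心": ["新", "芯"], "座": ["坐"], "该": ["改"]}  # toy confusion set

def corrupt(correct_text, p=0.1):
    chars, wrong_ids = list(correct_text), []
    for i, ch in enumerate(chars):
        if ch in CONFUSION and random.random() < p:
            chars[i] = random.choice(CONFUSION[ch])  # swap in a similar wrong char
            wrong_ids.append(i)
    return "".join(chars), wrong_ids  # original_text and its wrong_ids
```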
### Train the MacBert4CSC model
```shell
python train.py
```
Note: the MacBert4CSC model only handles correction of aligned text; it cannot fix missing or extra characters, so every training pair's `original_text` and `correct_text` must have the same length.
Otherwise training fails with: "ValueError: Expected input batch_size (*A) to match target batch_size (*B)." A filtering sketch follows.
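A minimal sketch that drops misaligned pairs before training (the file path follows the toy dataset layout above):
```python
import json

with open('output/train.json', encoding='utf-8') as f:
    data = json.load(f)
aligned = [d for d in data if len(d['original_text']) == len(d['correct_text'])]
with open('output/train.json', 'w', encoding='utf-8') as f:
    json.dump(aligned, f, ensure_ascii=False, indent=2)
```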
### Prediction
- Method 1: load the saved ckpt file directly:
```shell
python infer.py
```
- Method 2: load the `pytorch_model.bin` file:
Copy the model files listed below from the `output/macbert4csc` folder into `~/.pycorrector/datasets/macbert_models/chinese_finetuned_correction`;
you can then call the model with pycorrector or transformers exactly as in the Quick load section above (a copy-command sketch follows the file list).
```shell
output
└── macbert4csc
├── config.json
├── pytorch_model.bin
├── special_tokens_map.json
├── tokenizer_config.json
└── vocab.txt
```
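One possible way to copy them (paths taken from the text above):
```shell
mkdir -p ~/.pycorrector/datasets/macbert_models/chinese_finetuned_correction
cp output/macbert4csc/config.json \
   output/macbert4csc/pytorch_model.bin \
   output/macbert4csc/special_tokens_map.json \
   output/macbert4csc/tokenizer_config.json \
   output/macbert4csc/vocab.txt \
   ~/.pycorrector/datasets/macbert_models/chinese_finetuned_correction/
```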
Demo example [macbert_corrector.py](macbert_corrector.py):
```shell
python3 macbert_corrector.py
```
### Train the SoftMaskedBert4CSC model
```shell
python train.py --config_file train_softmaskedbert4csc.yml
```
# Reference
- [BertBasedCorrectionModels](https://github.com/gitabtion/BertBasedCorrectionModels)
- Cui, Y., Che, W., Liu, T., Qin, B., Wang, S., & Hu, G. (2020). Revisiting Pre-Trained Models for Chinese Natural Language Processing. *Findings of EMNLP*, 657–668. https://doi.org/10.18653/v1/2020.findings-emnlp.58 (the publication for [MacBERT](https://arxiv.org/pdf/2004.13922.pdf))

__init__.py (+0)

base_model.py (+191)

@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import operator
from abc import ABC

from loguru import logger
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl

from pycorrector.macbert import lr_scheduler
from pycorrector.macbert.evaluate_util import compute_corrector_prf, compute_sentence_level_prf


class FocalLoss(nn.Module):
    """
    Softmax and sigmoid focal loss.
    copy from https://github.com/lonePatient/TorchBlocks
    """

    def __init__(self, num_labels, activation_type='softmax', gamma=2.0, alpha=0.25, epsilon=1.e-9):
        super(FocalLoss, self).__init__()
        self.num_labels = num_labels
        self.gamma = gamma
        self.alpha = alpha
        self.epsilon = epsilon
        self.activation_type = activation_type

    def forward(self, input, target):
        """
        Args:
            input: model's output, shape of [batch_size, num_cls]
            target: ground truth labels, shape of [batch_size]
        Returns:
            shape of [batch_size]
        """
        if self.activation_type == 'softmax':
            idx = target.view(-1, 1).long()
            one_hot_key = torch.zeros(idx.size(0), self.num_labels, dtype=torch.float32, device=idx.device)
            one_hot_key = one_hot_key.scatter_(1, idx, 1)
            logits = torch.softmax(input, dim=-1)
            loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss = loss.sum(1)
        elif self.activation_type == 'sigmoid':
            multi_hot_key = target
            logits = torch.sigmoid(input)
            zero_hot_key = 1 - multi_hot_key
            loss = -self.alpha * multi_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()
            loss += -(1 - self.alpha) * zero_hot_key * torch.pow(logits, self.gamma) * (1 - logits + self.epsilon).log()
        return loss.mean()


def make_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    return optimizer


def build_lr_scheduler(cfg, optimizer):
    scheduler_args = {
        "optimizer": optimizer,

        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,

        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,

        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
    }
    scheduler = getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
    return {'scheduler': scheduler, 'interval': cfg.SOLVER.INTERVAL}


class BaseTrainingEngine(pl.LightningModule):
    def __init__(self, cfg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cfg = cfg

    def configure_optimizers(self):
        optimizer = make_optimizer(self.cfg, self)
        scheduler = build_lr_scheduler(self.cfg, optimizer)
        return [optimizer], [scheduler]

    def on_validation_epoch_start(self) -> None:
        logger.info('Valid.')

    def on_test_epoch_start(self) -> None:
        logger.info('Testing...')


class CscTrainingModel(BaseTrainingEngine, ABC):
    """
    Base model for CSC: defines the training and prediction steps.
    """

    def __init__(self, cfg, *args, **kwargs):
        super().__init__(cfg, *args, **kwargs)
        # loss weight
        self.w = cfg.MODEL.HYPER_PARAMS[0]

    def training_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        # weighted sum of correction loss (outputs[1]) and detection loss (outputs[0])
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(ori_text))
        return loss

    def validation_step(self, batch, batch_idx):
        ori_text, cor_text, det_labels = batch
        outputs = self.forward(ori_text, cor_text, det_labels)
        loss = self.w * outputs[1] + (1 - self.w) * outputs[0]
        det_y_hat = (outputs[2] > 0.5).long()
        cor_y_hat = torch.argmax((outputs[3]), dim=-1)
        encoded_x = self.tokenizer(cor_text, padding=True, return_tensors='pt')
        encoded_x.to(self._device)
        cor_y = encoded_x['input_ids']
        cor_y_hat *= encoded_x['attention_mask']

        results = []
        det_acc_labels = []
        cor_acc_labels = []
        for src, tgt, predict, det_predict, det_label in zip(ori_text, cor_y, cor_y_hat, det_y_hat, det_labels):
            _src = self.tokenizer(src, add_special_tokens=False)['input_ids']
            _tgt = tgt[1:len(_src) + 1].cpu().numpy().tolist()
            _predict = predict[1:len(_src) + 1].cpu().numpy().tolist()
            cor_acc_labels.append(1 if operator.eq(_tgt, _predict) else 0)
            det_acc_labels.append(det_predict[1:len(_src) + 1].equal(det_label[1:len(_src) + 1]))
            results.append((_src, _tgt, _predict,))

        return loss.cpu().item(), det_acc_labels, cor_acc_labels, results

    def validation_epoch_end(self, outputs) -> None:
        det_acc_labels = []
        cor_acc_labels = []
        results = []
        for out in outputs:
            det_acc_labels += out[1]
            cor_acc_labels += out[2]
            results += out[3]
        loss = np.mean([out[0] for out in outputs])
        self.log('val_loss', loss)
        logger.info(f'loss: {loss}')
        logger.info(f'Detection:\n'
                    f'acc: {np.mean(det_acc_labels):.4f}')
        logger.info(f'Correction:\n'
                    f'acc: {np.mean(cor_acc_labels):.4f}')
        compute_corrector_prf(results, logger)
        compute_sentence_level_prf(results, logger)

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs) -> None:
        logger.info('Test.')
        self.validation_epoch_end(outputs)

    def predict(self, texts):
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt')
        inputs.to(self.cfg.MODEL.DEVICE)
        with torch.no_grad():
            outputs = self.forward(texts)
            y_hat = torch.argmax(outputs[1], dim=-1)
            expand_text_lens = torch.sum(inputs['attention_mask'], dim=-1) - 1
        rst = []
        for t_len, _y_hat in zip(expand_text_lens, y_hat):
            rst.append(self.tokenizer.decode(_y_hat[1:t_len]).replace(' ', ''))
        return rst

ceshifenli.py (+94)

@@ -0,0 +1,94 @@
import unicodedata


# Scratch script: split a mixed string into Chinese / non-Chinese spans, then
# merge the (possibly corrected) Chinese characters back into their positions.
def is_chinese(char):
    if 'CJK' in unicodedata.name(char):
        return True
    else:
        return False


a = "ab1我们12是一个"
b = [""] * len(a)
last_post = False
c = []
for i, d in enumerate(a):
    bool_ = is_chinese(d)
    if bool_ == False:
        b[i] = d
        last_post = False
    else:
        if last_post == False:
            c.append([(i, d)])
        else:
            c[-1].append((i, d))
        last_post = True
print(c)
print(b)

d = []
for i in c:
    d.append("".join([j[1] for j in i]))
print(d)

e = d
f = ""
for i in e:
    f += i
f_list = list(f)
print(f_list)
for i, d in enumerate(b):
    if d == "":
        zi = f_list.pop(0)
        print(zi)
        b[i] = zi
print(b)


class SentenceUlit:
    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)
        self.last_post = False
        self.sentence_batch = []
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        if 'CJK' in unicodedata.name(char):
            return True
        else:
            return False

    def pre_ulit(self):
        for i, d in enumerate(self.sentence):
            bool_ = is_chinese(d)
            if bool_ == False:
                self.sentence_list[i] = d
                self.last_post = False
            else:
                if self.last_post == False:
                    self.sentence_batch.append(d)
                else:
                    self.sentence_batch[-1] += d
                self.last_post = True

    def inf_ulit(self, sen):
        for i in sen:
            self.inf_sentence_batch_str += i
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)
        for i, d in enumerate(self.sentence_list):
            if d == "":
                zi = self.inf_sentence_batch_srt_list.pop(0)
                self.sentence_list[i] = zi


sen = SentenceUlit("ab1我们12是一个")
print(sen.sentence_batch)
print(sen.sentence_list)

correct_demo.py (+43)

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""
@Time    :   2021-02-03 21:57:15
@File    :   correct_demo.py
@Author  :   Abtion
@Email   :   abtion{at}outlook.com
"""
import argparse
import sys

from pycorrector.macbert.macbert_corrector import MacBertCorrector
from pycorrector import config


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--macbert_model_dir", default=config.macbert_model_dir,
                        type=str,
                        help="MacBert pre-trained model dir")
    args = parser.parse_args()

    nlp = MacBertCorrector(args.macbert_model_dir).macbert_correct

    i = nlp('今天新情很好')
    print(i)
    i = nlp('少先队员英该为老人让座')
    print(i)
    i = nlp('机器学习是人工智能领遇最能体现智能的一个分知。')
    print(i)
    i = nlp('机其学习是人工智能领遇最能体现智能的一个分知。')
    print(i)
    print(nlp('老是较书。'))
    print(nlp('遇到一位很棒的奴生跟我聊天。'))


if __name__ == "__main__":
    main()

defaults.py (+114)

@@ -0,0 +1,114 @@
"""
@Time : 2021-01-21 10:37:36
@File : defaults.py
@Author : Abtion
@Email : abtion{at}outlook.com
"""
from yacs.config import CfgNode as CN
# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
_C = CN()
_C.MODEL = CN()
_C.MODEL.DEVICE = "cpu"
_C.MODEL.GPU_IDS = [0]
_C.MODEL.NUM_CLASSES = 10
_C.MODEL.BERT_CKPT = 'bert-base-chinese'
_C.MODEL.NAME = ''
_C.MODEL.WEIGHTS = ''
_C.MODEL.HYPER_PARAMS = []
# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# Max length of input text.
_C.INPUT.MAX_LEN = 512
# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training, as present in paths_catalog.py
_C.DATASETS.TRAIN = ""
# List of the dataset names for validation, as present in paths_catalog.py
_C.DATASETS.VALID = ""
# List of the dataset names for testing, as present in paths_catalog.py
_C.DATASETS.TEST = ""
# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 4
# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.OPTIMIZER_NAME = "AdamW"
_C.SOLVER.MAX_EPOCHS = 50
_C.SOLVER.BASE_LR = 0.001
_C.SOLVER.BIAS_LR_FACTOR = 2
_C.SOLVER.MOMENTUM = 0.9
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.WEIGHT_DECAY_BIAS = 0
_C.SOLVER.GAMMA = 0.9999
_C.SOLVER.STEPS = (10,)
_C.SOLVER.SCHED = "WarmupExponentialLR"
_C.SOLVER.WARMUP_FACTOR = 0.01
_C.SOLVER.WARMUP_ITERS = 2
_C.SOLVER.WARMUP_EPOCHS = 1024
_C.SOLVER.WARMUP_METHOD = "linear"
_C.SOLVER.DELAY_ITERS = 0
_C.SOLVER.ETA_MIN_LR = 3e-7
_C.SOLVER.MAX_ITER = 10
_C.SOLVER.INTERVAL = 'step'
_C.SOLVER.CHECKPOINT_PERIOD = 10
_C.SOLVER.LOG_PERIOD = 100
_C.SOLVER.ACCUMULATE_GRAD_BATCHES = 1
# Number of images per batch
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
# see 2 images per batch
_C.SOLVER.BATCH_SIZE = 16
_C.TEST = CN()
_C.TEST.BATCH_SIZE = 8
_C.TEST.CKPT_FN = ""
# ---------------------------------------------------------------------------- #
# Task specific
# ---------------------------------------------------------------------------- #
_C.TASK = CN()
_C.TASK.NAME = "CSC"
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
_C.OUTPUT_DIR = ""
_C.MODE = ['train', 'test']

evaluate_util.py (+255)

@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""


def compute_corrector_prf(results, logger):
    """
    copy from https://github.com/sunnyqiny/Confusionset-guided-Pointer-Networks-for-Chinese-Spelling-Check/blob/master/utils/evaluation_metrics.py
    """
    TP = 0
    FP = 0
    FN = 0
    all_predict_true_index = []
    all_gold_index = []
    for item in results:
        src, tgt, predict = item
        gold_index = []
        each_true_index = []
        for i in range(len(list(src))):
            if src[i] == tgt[i]:
                continue
            else:
                gold_index.append(i)
        all_gold_index.append(gold_index)
        predict_index = []
        for i in range(len(list(src))):
            if src[i] == predict[i]:
                continue
            else:
                predict_index.append(i)

        for i in predict_index:
            if i in gold_index:
                TP += 1
                each_true_index.append(i)
            else:
                FP += 1
        for i in gold_index:
            if i in predict_index:
                continue
            else:
                FN += 1
        all_predict_true_index.append(each_true_index)

    # For the detection Precision, Recall and F1
    detection_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    detection_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    if detection_precision + detection_recall == 0:
        detection_f1 = 0
    else:
        detection_f1 = 2 * (detection_precision * detection_recall) / (detection_precision + detection_recall)
    logger.info(
        "The detection result is precision={}, recall={} and F1={}".format(detection_precision, detection_recall,
                                                                           detection_f1))

    TP = 0
    FP = 0
    FN = 0
    for i in range(len(all_predict_true_index)):
        # we only check the correctly detected locations, which differs from the common
        # metrics, since we want to see the precision improvement gained from the confusion set
        if len(all_predict_true_index[i]) > 0:
            predict_words = []
            for j in all_predict_true_index[i]:
                predict_words.append(results[i][2][j])
                if results[i][1][j] == results[i][2][j]:
                    TP += 1
                else:
                    FP += 1
            for j in all_gold_index[i]:
                if results[i][1][j] in predict_words:
                    continue
                else:
                    FN += 1

    # For the correction Precision, Recall and F1
    correction_precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    correction_recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    if correction_precision + correction_recall == 0:
        correction_f1 = 0
    else:
        correction_f1 = 2 * (correction_precision * correction_recall) / (correction_precision + correction_recall)
    logger.info("The correction result is precision={}, recall={} and F1={}".format(correction_precision,
                                                                                    correction_recall,
                                                                                    correction_f1))
    return detection_f1, correction_f1


def compute_sentence_level_prf(results, logger):
    """
    Custom sentence-level PRF: sentences that need correction are positive
    samples, sentences that need no correction are negative samples.
    :param results:
    :return:
    """
    TP = 0.0
    FP = 0.0
    FN = 0.0
    TN = 0.0
    total_num = len(results)

    for item in results:
        src, tgt, predict = item
        # negative sample
        if src == tgt:
            # predicted negative
            if tgt == predict:
                TN += 1
            # predicted positive
            else:
                FP += 1
        # positive sample
        else:
            # predicted positive
            if tgt == predict:
                TP += 1
            # predicted negative
            else:
                FN += 1

    acc = (TP + TN) / total_num
    precision = TP / (TP + FP) if TP > 0 else 0.0
    recall = TP / (TP + FN) if TP > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0
    logger.info(f'Sentence Level: acc:{acc:.6f}, precision:{precision:.6f}, recall:{recall:.6f}, f1:{f1:.6f}')
    return acc, precision, recall, f1


def report_prf(tp, fp, fn, phase, logger=None, return_dict=False):
    # For the detection Precision, Recall and F1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    if precision + recall == 0:
        f1_score = 0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)
    if phase and logger:
        logger.info(f"The {phase} result is: "
                    f"{precision:.4f}/{recall:.4f}/{f1_score:.4f} -->\n"
                    # f"precision={precision:.6f}, recall={recall:.6f} and F1={f1_score:.6f}\n"
                    f"support: TP={tp}, FP={fp}, FN={fn}")
    if return_dict:
        ret_dict = {
            f'{phase}_p': precision,
            f'{phase}_r': recall,
            f'{phase}_f1': f1_score}
        return ret_dict
    return precision, recall, f1_score


def compute_corrector_prf_faspell(results, logger=None, strict=True):
    """
    All-in-one measure function.
    based on FASpell's measure script.
    :param results: a list of (wrong, correct, predict, ...)
        both token_ids or characters are fine for the script.
    :param logger: take which logger to print logs.
    :param strict: a more strict evaluation mode (all-char-detected/corrected)
    References:
        sentence-level PRF: https://github.com/iqiyi/
        FASPell/blob/master/faspell.py
    """
    corrected_char, wrong_char = 0, 0
    corrected_sent, wrong_sent = 0, 0
    true_corrected_char = 0
    true_corrected_sent = 0
    true_detected_char = 0
    true_detected_sent = 0
    accurate_detected_sent = 0
    accurate_corrected_sent = 0
    all_sent = 0

    for item in results:
        # wrong, correct, predict, d_tgt, d_predict = item
        wrong, correct, predict = item[:3]

        all_sent += 1
        wrong_num = 0
        corrected_num = 0
        original_wrong_num = 0
        true_detected_char_in_sentence = 0

        for c, w, p in zip(correct, wrong, predict):
            if c != p:
                wrong_num += 1
            if w != p:
                corrected_num += 1
                if c == p:
                    true_corrected_char += 1
                if w != c:
                    true_detected_char += 1
                    true_detected_char_in_sentence += 1
            if c != w:
                original_wrong_num += 1

        corrected_char += corrected_num
        wrong_char += original_wrong_num
        if original_wrong_num != 0:
            wrong_sent += 1
        if corrected_num != 0 and wrong_num == 0:
            true_corrected_sent += 1

        if corrected_num != 0:
            corrected_sent += 1

        if strict:  # find out all faulty wordings' positions
            true_detected_flag = (true_detected_char_in_sentence == original_wrong_num
                                  and original_wrong_num != 0
                                  and corrected_num == true_detected_char_in_sentence)
        else:  # assume the sentence has faulty wordings
            true_detected_flag = (corrected_num != 0 and original_wrong_num != 0)
        # if corrected_num != 0 and original_wrong_num != 0:
        if true_detected_flag:
            true_detected_sent += 1
        if correct == predict:
            accurate_corrected_sent += 1
        if correct == predict or true_detected_flag:
            accurate_detected_sent += 1

    counts = {  # TP, FP, TN for each level
        'det_char_counts': [true_detected_char,
                            corrected_char - true_detected_char,
                            wrong_char - true_detected_char],
        'cor_char_counts': [true_corrected_char,
                            corrected_char - true_corrected_char,
                            wrong_char - true_corrected_char],
        'det_sent_counts': [true_detected_sent,
                            corrected_sent - true_detected_sent,
                            wrong_sent - true_detected_sent],
        'cor_sent_counts': [true_corrected_sent,
                            corrected_sent - true_corrected_sent,
                            wrong_sent - true_corrected_sent],
        'det_sent_acc': accurate_detected_sent / all_sent,
        'cor_sent_acc': accurate_corrected_sent / all_sent,
        'all_sent_count': all_sent,
    }

    details = {}
    for phase in ['det_char', 'cor_char', 'det_sent', 'cor_sent']:
        dic = report_prf(
            *counts[f'{phase}_counts'],
            phase=phase, logger=logger,
            return_dict=True)
        details.update(dic)
    details.update(counts)
    return details

flask_macbert.py (+159)

@@ -0,0 +1,159 @@
import os
from flask import Flask, jsonify
from flask import request
import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import uuid
import json
from threading import Thread
import time
import re
import logging
import unicodedata

logging.basicConfig(level=logging.DEBUG,  # console log level
                    filename='rewrite.log',
                    filemode='a',  # mode: 'w' rewrites the log file on every run, 'a' (the default) appends
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'  # log format
                    )

db_key_query = 'query'
batch_size = 32
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", "!", "?", ";", "…"]  # sentence-ending punctuation

tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)


def is_chinese(char):
    if 'CJK' in unicodedata.name(char):
        return True
    else:
        return False


class SentenceUlit:
    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)
        self.last_post = False
        self.sentence_batch = []
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        if 'CJK' in unicodedata.name(char):
            return True
        else:
            return False

    def pre_ulit(self):
        for i, d in enumerate(self.sentence):
            bool_ = is_chinese(d)
            if bool_ == False:
                self.sentence_list[i] = d
                self.last_post = False
            else:
                if self.last_post == False:
                    self.sentence_batch.append(d)
                else:
                    self.sentence_batch[-1] += d
                self.last_post = True

    def inf_ulit(self, sen):
        for i in sen:
            self.inf_sentence_batch_str += i
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)
        for i, d in enumerate(self.sentence_list):
            if d == "":
                zi = self.inf_sentence_batch_srt_list.pop(0)
                self.sentence_list[i] = zi


class log:
    def __init__(self):
        pass

    def log(*args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        if not os.path.exists(log_file):
            with open(os.path.join(log_file), 'w', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)
        else:
            with open(os.path.join(log_file), 'a+', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)


def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # add unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details


def main(texts):
    with torch.no_grad():
        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))

    result = []
    print(outputs.logits)
    for ids, text in zip(outputs.logits, texts):
        _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
        corrected_text = _text[:len(text)]
        print(corrected_text)
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text,
                       "new": corrected_text,
                       "re_pos": details})
    return result


@app.route("/predict", methods=["POST"])
def handle_query():
    print(request.remote_addr)
    texts = request.json["texts"]
    return_list = main(texts)
    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)

infer.py (+116)

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com), okcd00(okcd00@qq.com)
@description:
"""
import sys
import torch
import argparse
from transformers import BertTokenizerFast
from loguru import logger

sys.path.append('../..')
from pycorrector.macbert.macbert4csc import MacBert4Csc
from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
from pycorrector.macbert.macbert_corrector import get_errors
from pycorrector.macbert.defaults import _C as cfg

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Inference:
    def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
                 vocab_path='output/macbert4csc/vocab.txt',
                 cfg_path='train_macbert4csc.yml'):
        logger.debug("device: {}".format(device))
        self.tokenizer = BertTokenizerFast.from_pretrained(vocab_path)
        cfg.merge_from_file(cfg_path)
        if 'macbert4csc' in cfg_path:
            self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
                                                          cfg=cfg,
                                                          map_location=device,
                                                          tokenizer=self.tokenizer)
        elif 'softmaskedbert4csc' in cfg_path:
            self.model = SoftMaskedBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
                                                                 cfg=cfg,
                                                                 map_location=device,
                                                                 tokenizer=self.tokenizer)
        else:
            raise ValueError("model not found.")
        self.model.to(device)
        self.model.eval()

    def predict(self, sentence_list):
        """
        Run the correction model.
        Args:
            sentence_list: list
                input texts
        Returns: tuple
            corrected_texts(list)
        """
        is_str = False
        if isinstance(sentence_list, str):
            is_str = True
            sentence_list = [sentence_list]
        corrected_texts = self.model.predict(sentence_list)
        if is_str:
            return corrected_texts[0]
        return corrected_texts

    def predict_with_error_detail(self, sentence_list):
        """
        Run the correction model and return error-position details.
        Args:
            sentence_list: list
                input texts
        Returns: tuple
            corrected_texts(list), details(list)
        """
        details = []
        is_str = False
        if isinstance(sentence_list, str):
            is_str = True
            sentence_list = [sentence_list]
        corrected_texts = self.model.predict(sentence_list)
        for corrected_text, text in zip(corrected_texts, sentence_list):
            corrected_text, sub_details = get_errors(corrected_text, text)
            details.append(sub_details)
        if is_str:
            return corrected_texts[0], details[0]
        return corrected_texts, details


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--ckpt_path", default="output/macbert4csc/epoch=09-val_loss=0.01.ckpt",
                        help="path to ckpt file", type=str)
    parser.add_argument("--vocab_path", default="output/macbert4csc/vocab.txt", help="path to vocab file", type=str)
    parser.add_argument("--config_file", default="train_macbert4csc.yml", help="path to config file", type=str)
    args = parser.parse_args()
    m = Inference(args.ckpt_path, args.vocab_path, args.config_file)
    inputs = [
        '它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。',
        '老是较书。',
        '少先队 员因该 为老人让 坐',
        '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
        '遇到一位很棒的奴生跟我聊天。',
        '遇到一位很美的女生跟我疗天。',
        '他们只能有两个选择:接受降新或自动离职。',
        '王天华开心得一直说话。',
        '你说:“怎么办?”我怎么知道?',
    ]
    outputs = m.predict(inputs)
    for a, b in zip(inputs, outputs):
        print('input  :', a)
        print('predict:', b)
        print()
    # evaluate the model on the sighan2015 test set
    # macbert4csc Sentence Level: acc:0.7845, precision:0.8174, recall:0.7256, f1:0.7688, cost time:10.79 s
    # softmaskedbert4csc Sentence Level: acc:0.6964, precision:0.8065, recall:0.5064, f1:0.6222, cost time:16.20 s
    from pycorrector.utils.eval import eval_sighan2015_by_model

    eval_sighan2015_by_model(m.predict_with_error_detail)

lr_scheduler.py (+178)

@@ -0,0 +1,178 @@
"""
@Time    :   2021-01-21 10:52:47
@File    :   lr_scheduler.py
@Author  :   Abtion
@Email   :   abtion{at}outlook.com
"""
import math
import warnings
from bisect import bisect_right
from typing import List
import torch
from torch.optim.lr_scheduler import _LRScheduler

__all__ = ["WarmupMultiStepLR", "WarmupExponentialLR", "WarmupCosineAnnealingLR"]


class WarmupMultiStepLR(_LRScheduler):
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            milestones: List[int],
            gamma: float = 0.1,
            warmup_factor: float = 0.001,
            warmup_epochs: int = 2,
            warmup_method: str = "linear",
            last_epoch: int = -1,
            **kwargs,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
        )
        return [
            base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def _compute_values(self) -> List[float]:
        # The new interface
        return self.get_lr()


class WarmupExponentialLR(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, warmup_epochs=2, warmup_factor=1.0 / 3, verbose=False,
                 **kwargs):
        self.gamma = gamma
        self.warmup_method = 'linear'
        self.warmup_epochs = warmup_epochs
        self.warmup_factor = warmup_factor
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        warmup_factor = _get_warmup_factor_at_iter(
            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
        )
        if self.last_epoch <= self.warmup_epochs:
            return [base_lr * warmup_factor
                    for base_lr in self.base_lrs]
        return [group['lr'] * self.gamma
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * self.gamma ** self.last_epoch
                for base_lr in self.base_lrs]


class WarmupCosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iters (int): Maximum number of iterations.
        eta_min_lr (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            max_iters: int,
            delay_iters: int = 0,
            eta_min_lr: int = 0,
            warmup_factor: float = 0.001,
            warmup_epochs: int = 2,
            warmup_method: str = "linear",
            last_epoch=-1,
            **kwargs
    ):
        self.max_iters = max_iters
        self.delay_iters = delay_iters
        self.eta_min_lr = eta_min_lr
        self.warmup_factor = warmup_factor
        self.warmup_epochs = warmup_epochs
        self.warmup_method = warmup_method
        assert self.delay_iters >= self.warmup_epochs, "Scheduler delay iters must be larger than warmup iters"
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch <= self.warmup_epochs:
            warmup_factor = _get_warmup_factor_at_iter(
                self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor,
            )
            return [
                base_lr * warmup_factor for base_lr in self.base_lrs
            ]
        elif self.last_epoch <= self.delay_iters:
            return self.base_lrs
        else:
            return [
                self.eta_min_lr + (base_lr - self.eta_min_lr) *
                (1 + math.cos(
                    math.pi * (self.last_epoch - self.delay_iters) / (self.max_iters - self.delay_iters))) / 2
                for base_lr in self.base_lrs]


def _get_warmup_factor_at_iter(
        method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See https://arxiv.org/abs/1706.02677 for more details.

    Args:
        method (str): warmup method; either "constant" or "linear".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).

    Returns:
        float: the effective warmup factor at the given iteration.
    """
    if iter >= warmup_iters:
        return 1.0
    if method == "constant":
        return warmup_factor
    elif method == "linear":
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha
    else:
        raise ValueError("Unknown warmup method: {}".format(method))

macbert4csc.py (+50)

@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
from abc import ABC

import torch.nn as nn
from transformers import BertForMaskedLM

from pycorrector.macbert.base_model import CscTrainingModel, FocalLoss


class MacBert4Csc(CscTrainingModel, ABC):
    def __init__(self, cfg, tokenizer):
        super().__init__(cfg)
        self.cfg = cfg
        self.bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
        self.detection = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.tokenizer = tokenizer

    def forward(self, texts, cor_labels=None, det_labels=None):
        if cor_labels:
            text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            text_labels[text_labels == 0] = -100  # -100 is ignored when computing the loss
            text_labels = text_labels.to(self.device)
        else:
            text_labels = None
        encoded_text = self.tokenizer(texts, padding=True, return_tensors='pt')
        encoded_text.to(self.device)
        bert_outputs = self.bert(**encoded_text, labels=text_labels, return_dict=True, output_hidden_states=True)
        # error-detection probability
        prob = self.detection(bert_outputs.hidden_states[-1])

        if text_labels is None:
            # (detection output, correction output)
            outputs = (prob, bert_outputs.logits)
        else:
            det_loss_fct = FocalLoss(num_labels=None, activation_type='sigmoid')
            # the pad part does not contribute to the loss
            active_loss = encoded_text['attention_mask'].view(-1, prob.shape[1]) == 1
            active_probs = prob.view(-1, prob.shape[1])[active_loss]
            active_labels = det_labels[active_loss]
            det_loss = det_loss_fct(active_probs, active_labels.float())
            # (detection loss, correction loss, detection output, correction output)
            outputs = (det_loss,
                       bert_outputs.loss,
                       self.sigmoid(prob).squeeze(-1),
                       bert_outputs.logits)
        return outputs

macbert_corrector.py (+166)

@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import operator
import sys
import time
import os
from transformers import BertTokenizerFast, BertForMaskedLM
import torch
from typing import List
from loguru import logger

sys.path.append('../..')
from pycorrector import config
from pycorrector.utils.tokenizer import split_text_by_maxlen

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
unk_tokens = [' ', '“', '”', '‘', '’', '\n', '…', '—', '擤', '\t', '֍', '琊', '玕']


def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if i >= len(corrected_text):
            break
        if ori_char in unk_tokens:
            # deal with unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details


class MacBertCorrector(object):
    def __init__(self, model_dir=config.macbert_model_dir):
        self.name = 'macbert_corrector'
        t1 = time.time()
        bin_path = os.path.join(model_dir, 'pytorch_model.bin')
        if not os.path.exists(bin_path):
            model_dir = "shibing624/macbert4csc-base-chinese"
            logger.warning(f'local model {bin_path} not exists, use default HF model {model_dir}')

        self.tokenizer = BertTokenizerFast.from_pretrained(model_dir)
        self.model = BertForMaskedLM.from_pretrained(model_dir)
        self.model.to(device)
        logger.debug("Use device: {}".format(device))
        logger.debug('Loaded macbert4csc model: %s, spend: %.3f s.' % (model_dir, time.time() - t1))

    def macbert_correct(self, text, threshold=0.7, verbose=False):
        """
        Correct a sentence.
        :param text: input sentence
        :param threshold: probability threshold for accepting a corrected token
        :param verbose: whether to log details
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        text_new = ''
        details = []
        # split long text into short blocks
        blocks = split_text_by_maxlen(text, maxlen=128)
        block_texts = [block[0] for block in blocks]
        inputs = self.tokenizer(block_texts, padding=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        for ids, (text, idx) in zip(outputs.logits, blocks):
            decode_tokens_new = self.tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).split(' ')
            decode_tokens_old = self.tokenizer.decode(inputs['input_ids'][idx], skip_special_tokens=True).split(' ')
            if len(decode_tokens_new) != len(decode_tokens_old):
                continue
            probs = torch.max(torch.softmax(ids, dim=-1), dim=-1)[0].cpu().numpy()
            decode_tokens = ''
            for i in range(len(decode_tokens_old)):
                if probs[i + 1] >= threshold:
                    if verbose:
                        logger.debug(
                            f"word: {decode_tokens_old[i]}, prob: {probs[i + 1]}, new word: {decode_tokens_new[i]}")
                    decode_tokens += decode_tokens_new[i]
                else:
                    decode_tokens += decode_tokens_old[i]
            corrected_text = decode_tokens[:len(text)]
            corrected_text, sub_details = get_errors(corrected_text, text)
            text_new += corrected_text
            sub_details = [(i[0], i[1], idx + i[2], idx + i[3]) for i in sub_details]
            details.extend(sub_details)
        return text_new, details

    def batch_macbert_correct(self, texts: List[str], max_length: int = 128):
        """
        Correct a batch of sentences.
        :param texts: list[str], sentence list
        :param max_length: int, max length of each sentence
        :return: corrected_text, list[list], [error_word, correct_word, begin_pos, end_pos]
        """
        result = []
        inputs = self.tokenizer(texts, padding=True, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        for ids, (i, text) in zip(outputs.logits, enumerate(texts)):
            text_new = ''
            details = []
            corrected_text = self.tokenizer.decode((torch.argmax(ids, dim=-1) * inputs.attention_mask[i]),
                                                   skip_special_tokens=True).replace(' ', '')
            corrected_text, sub_details = get_errors(corrected_text, text)
            text_new += corrected_text
            sub_details = [(i[0], i[1], i[2], i[3]) for i in sub_details]
            details.extend(sub_details)
            result.append([text_new, details])
        return result


if __name__ == "__main__":
    m = MacBertCorrector()
    error_sentences = [
        '内容提要——在知识产权学科领域里',
        '疝気医院那好 为老人让坐,疝気专科百科问答',
        '少先队员因该为老人让坐',
        '少 先 队 员 因 该 为 老人让坐',
        '机七学习是人工智能领遇最能体现智能的一个分知',
        '今天心情很好',
        '老是较书。',
        '遇到一位很棒的奴生跟我聊天。',
        '他的语说的很好,法语也不错',
        '他法语说的很好,的语也不错',
        '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
        '影像小孩子想的快,学习管理的斑法',
        '餐厅的换经费产适合约会',
        '走路真的麻坊,我也没有喝的东西,在家汪了',
        '因为爸爸在看录音机,所以我没得看',
        '不过在许多传统国家,女人向未得到平等',
        '妈妈说:"别趴地上了,快起来,你还吃饭吗?",我说:"好。"就扒起来了。',
        '你说:“怎么办?”我怎么知道?',
        '我父母们常常说:“那时候吃的东西太少,每天只能吃一顿饭。”想一想,人们都快要饿死,谁提出化肥和农药的污染。',
        '这本新书《居里夫人传》将的很生动有趣',
        '֍我喜欢吃鸡,公鸡、母鸡、白切鸡、乌鸡、紫燕鸡……֍新的食谱',
        '注意:“跨类保护”不等于“全类保护”。',
        '12.——对比文件中未公开的数值和对比文件中已经公开的中间值具有新颖性;',
        '《著作权法》(2020修正)第23条:“自然人的作品,其发表权、本法第',
        '三步检验法(三步检验标准)(three-step test):若要',
        '①申请人应提交,顺应了国家“健全创新激励',
        '①申请人应提交,太平天国领导人洪仁玕。',
        ' 部分优先权:',
        '实施其专利的行为(生产经营≠营利≠商业经营)',
        '实施,i can speak chinese, can i spea english. ? hello.',
        "我不唉“看 琅擤琊榜”",
    ]
    t1 = time.time()
    for sent in error_sentences:
        corrected_sent, err = m.macbert_correct(sent, 0.6)
        print("original sentence:{} => {} err:{}".format(sent, corrected_sent, err))
    print('[single]spend time:', time.time() - t1)
    t2 = time.time()
    res = m.batch_macbert_correct(error_sentences)
    for sent, r in zip(error_sentences, res):
        print("original sentence:{} => {} err:{}".format(sent, r[0], r[1]))
    print('[batch]spend time:', time.time() - t2)

predict.py (+46)

@@ -0,0 +1,46 @@
import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)

texts = ["今天新情很好,你找到你最喜欢的工作,我也很高心。", "今天新情很好,你找到你最喜欢的工作,我也很高心。"]
with torch.no_grad():
    input = tokenizer(texts, padding=True, return_tensors='pt').to(device)
    print(input)
    input_ids = input['input_ids'].to(device)
    token_type_ids = input["token_type_ids"].to(device)
    attention_mask = input['attention_mask'].to(device)
    print()
    # pass the tensors by keyword: positionally, token_type_ids would be taken
    # as attention_mask (the second positional parameter of BertForMaskedLM)
    outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # add unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details


result = []
for ids, text in zip(outputs.logits, texts):
    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    corrected_text, details = get_errors(corrected_text, text)
    print(text, ' => ', corrected_text, details)
    result.append((text, corrected_text, details))
print(result)

preprocess.py (+200)

@@ -0,0 +1,200 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import os
import sys
from codecs import open
from sklearn.model_selection import train_test_split
from lxml import etree
from xml.dom import minidom

sys.path.append('../..')
from pycorrector.utils.io_utils import save_json
from pycorrector import traditional2simplified

pwd_path = os.path.abspath(os.path.dirname(__file__))


def read_data(fp):
    for fn in os.listdir(fp):
        if fn.endswith('ing.sgml'):
            with open(os.path.join(fp, fn), 'r', encoding='utf-8', errors='ignore') as f:
                item = []
                for line in f:
                    if line.strip().startswith('<ESSAY') and len(item) > 0:
                        yield ''.join(item)
                        item = [line.strip()]
                    elif line.strip().startswith('<'):
                        item.append(line.strip())


def proc_item(item):
    """
    Process one item of the training data.
    Args:
        item:
    Returns:
        list
    """
    root = etree.XML(item)
    passages = dict()
    mistakes = []
    for passage in root.xpath('/ESSAY/TEXT/PASSAGE'):
        passages[passage.get('id')] = traditional2simplified(passage.text)
    for mistake in root.xpath('/ESSAY/MISTAKE'):
        mistakes.append({'id': mistake.get('id'),
                         'location': int(mistake.get('location')) - 1,
                         'wrong': traditional2simplified(mistake.xpath('./WRONG/text()')[0].strip()),
                         'correction': traditional2simplified(mistake.xpath('./CORRECTION/text()')[0].strip())})

    rst_items = dict()

    def get_passages_by_id(pgs, _id):
        p = pgs.get(_id)
        if p:
            return p
        _id = _id[:-1] + str(int(_id[-1]) + 1)
        p = pgs.get(_id)
        if p:
            return p
        raise ValueError(f'passage not found by {_id}')

    for mistake in mistakes:
        if mistake['id'] not in rst_items.keys():
            rst_items[mistake['id']] = {'original_text': get_passages_by_id(passages, mistake['id']),
                                        'wrong_ids': [],
                                        'correct_text': get_passages_by_id(passages, mistake['id'])}

        ori_text = rst_items[mistake['id']]['original_text']
        cor_text = rst_items[mistake['id']]['correct_text']
        if len(ori_text) == len(cor_text):
            if ori_text[mistake['location']] in mistake['wrong']:
                rst_items[mistake['id']]['wrong_ids'].append(mistake['location'])
                wrong_char_idx = mistake['wrong'].index(ori_text[mistake['location']])
                start = mistake['location'] - wrong_char_idx
                end = start + len(mistake['wrong'])
                rst_items[mistake['id']][
                    'correct_text'] = f'{cor_text[:start]}{mistake["correction"]}{cor_text[end:]}'
        else:
            print(f'error line:\n{mistake["id"]}\n{ori_text}\n{cor_text}')

    rst = []
    for k in rst_items.keys():
        if len(rst_items[k]['correct_text']) == len(rst_items[k]['original_text']):
            rst.append({'id': k, **rst_items[k]})
        else:
            text = rst_items[k]['correct_text']
            rst.append({'id': k, 'correct_text': text, 'original_text': text, 'wrong_ids': []})
    return rst


def proc_test_set(fp):
    """
    Build the SIGHAN15 test set.
    Args:
        fp:
    Returns:
    """
    inputs = dict()
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestInput.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            pid = line[5:14]
            text = line[16:].strip()
            inputs[pid] = text

    rst = []
    with open(os.path.join(fp, 'SIGHAN15_CSC_TestTruth.txt'), 'r', encoding='utf-8') as f:
        for line in f:
            pid = line[0:9]
            mistakes = line[11:].strip().split(', ')
            if len(mistakes) <= 1:
                text = traditional2simplified(inputs[pid])
                rst.append({'id': pid,
                            'original_text': text,
                            'wrong_ids': [],
                            'correct_text': text})
            else:
                wrong_ids = []
                original_text = inputs[pid]
                cor_text = inputs[pid]
                for i in range(len(mistakes) // 2):
                    idx = int(mistakes[2 * i]) - 1
                    cor_char = mistakes[2 * i + 1]
                    wrong_ids.append(idx)
                    cor_text = f'{cor_text[:idx]}{cor_char}{cor_text[idx + 1:]}'
                original_text = traditional2simplified(original_text)
                cor_text = traditional2simplified(cor_text)
                if len(original_text) != len(cor_text):
                    print('error line:\n', pid)
                    print(original_text)
                    print(cor_text)
                    continue
                rst.append({'id': pid,
                            'original_text': original_text,
                            'wrong_ids': wrong_ids,
                            'correct_text': cor_text})
    return rst


def parse_cged_file(file_dir):
    rst = []
    for fn in os.listdir(file_dir):
        if fn.endswith('.xml'):
            path = os.path.join(file_dir, fn)
            print('Parse data from %s' % path)
            dom_tree = minidom.parse(path)
            docs = dom_tree.documentElement.getElementsByTagName('DOC')
            for doc in docs:
                id = ''
                text = ''
                texts = doc.getElementsByTagName('TEXT')
                for i in texts:
                    id = i.getAttribute('id')
                    # Input the text
                    text = i.childNodes[0].data.strip()
                # Input the correct text
                correction = doc.getElementsByTagName('CORRECTION')[0]. \
                    childNodes[0].data.strip()
                wrong_ids = []
                for error in doc.getElementsByTagName('ERROR'):
                    start_off = error.getAttribute('start_off')
                    end_off = error.getAttribute('end_off')
                    if start_off and end_off:
                        for i in range(int(start_off), int(end_off) + 1):
                            wrong_ids.append(i)
                source = text.strip()
                target = correction.strip()
                pair = [source, target]
                # note: rst holds dicts, so this membership test never matches a [source, target] pair
                if pair not in rst:
                    rst.append({'id': id,
                                'original_text': source,
                                'wrong_ids': wrong_ids,
                                'correct_text': target
                                })
    save_json(rst, os.path.join(pwd_path, 'output/cged.json'))
    return rst


def main():
    # note: this training sample is small; use it only for testing the model
    # parse_cged_file(os.path.join(pwd_path, '../data/cn/CGED/'))
    sighan15_dir = os.path.join(pwd_path, '../data/cn/sighan_2015/')
    rst_items = []
    test_lst = proc_test_set(sighan15_dir)
    for item in read_data(sighan15_dir):
        rst_items += proc_item(item)

    # split train and dev sets
    print('data_size:', len(rst_items))
    train_lst, dev_lst = train_test_split(rst_items, test_size=0.1, random_state=42)
    save_json(train_lst, os.path.join(pwd_path, 'output/train.json'))
    save_json(dev_lst, os.path.join(pwd_path, 'output/dev.json'))
    save_json(test_lst, os.path.join(pwd_path, 'output/test.json'))


if __name__ == '__main__':
    main()

68
reader.py

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import os
import json
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
class DataCollator:
def __init__(self, tokenizer: BertTokenizerFast):
self.tokenizer = tokenizer
def __call__(self, data):
ori_texts, cor_texts, wrong_idss = zip(*data)
encoded_texts = [self.tokenizer(t, return_offsets_mapping=True, add_special_tokens=False) for t in ori_texts]
max_len = max([len(t['input_ids']) for t in encoded_texts]) + 2
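        # +2 leaves room for the [CLS]/[SEP] tokens added when the model tokenizes the batch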
det_labels = torch.zeros(len(ori_texts), max_len).long()
for i, (encoded_text, wrong_ids) in enumerate(zip(encoded_texts, wrong_idss)):
off_mapping = encoded_text['offset_mapping']
for idx in wrong_ids:
for j, (b, e) in enumerate(off_mapping):
if b <= idx < e:
                        # j + 1 accounts for the leading [CLS] token
det_labels[i, j + 1] = 1
break
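        # Worked example (hypothetical): for "我爱北京" with wrong_ids=[1], the
        # offset_mapping is [(0,1),(1,2),(2,3),(3,4)], so token j=1 covers the
        # wrong char and det_labels[i, 2] is set (shifted one slot by [CLS])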
return list(ori_texts), list(cor_texts), det_labels
class CscDataset(Dataset):
def __init__(self, file_path):
self.data = json.load(open(file_path, 'r', encoding='utf-8'))
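        # Each record: {"id", "original_text", "wrong_ids", "correct_text"},
        # as produced by preprocess.py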
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]['original_text'], self.data[index]['correct_text'], self.data[index]['wrong_ids']
def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
batch_size=32, num_workers=4):
train_loader = None
if train_path and os.path.exists(train_path):
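        # NOTE: as written the training loader is not shuffled; pass
        # shuffle=True below for randomized training batches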
train_loader = DataLoader(CscDataset(train_path),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn)
valid_loader = None
if valid_path and os.path.exists(valid_path):
valid_loader = DataLoader(CscDataset(valid_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn)
test_loader = None
if test_path and os.path.exists(test_path):
test_loader = DataLoader(CscDataset(test_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn)
return train_loader, valid_loader, test_loader
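# Minimal usage sketch (paths assumed to follow preprocess.py's output layout):
#   tokenizer = BertTokenizerFast.from_pretrained('hfl/chinese-macbert-base')
#   collator = DataCollator(tokenizer=tokenizer)
#   train_loader, valid_loader, test_loader = make_loaders(
#       collator, train_path='output/train.json', valid_path='output/dev.json')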

131
rewrite.log

@ -0,0 +1,131 @@
2023-05-12 11:44:02,567 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-05-12 11:44:02,567 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
2023-05-12 11:47:41,934 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-05-12 11:47:41,935 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
2023-05-12 11:47:43,032 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:47:43] "POST /predict HTTP/1.1" 200 -
2023-05-12 11:48:12,651 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:48:12] "POST /predict HTTP/1.1" 200 -
2023-05-12 11:48:32,073 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:48:32] "POST /predict HTTP/1.1" 200 -
2023-05-12 11:49:22,664 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-05-12 11:49:22,664 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.116:16000/ (Press CTRL+C to quit)
2023-05-12 11:49:25,857 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.116 - - [12/May/2023 11:49:25] "POST /predict HTTP/1.1" 200 -
2023-06-13 19:24:21,082 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-13 19:24:21,083 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-13 19:25:24,092 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:25:24] "POST /predict HTTP/1.1" 200 -
2023-06-13 19:26:11,123 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:26:11] "POST /predict HTTP/1.1" 200 -
2023-06-13 19:35:27,653 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:35:27] "POST /predict HTTP/1.1" 200 -
2023-06-13 19:37:38,520 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-13 19:37:38,520 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-13 19:44:29,271 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
Traceback (most recent call last):
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 248, in __getattr__
return self.data[item]
KeyError: 'size'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
response = self.full_dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
rv = self.handle_user_exception(e)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
raise value
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
rv = self.dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
return_list = main(texts)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
outputs = model(a)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1345, in forward
return_dict=return_dict,
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 944, in forward
input_shape = input_ids.size()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 250, in __getattr__
raise AttributeError
AttributeError
2023-06-13 19:44:29,291 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:44:29] "POST /predict HTTP/1.1" 500 -
2023-06-13 19:45:26,160 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-13 19:45:26,161 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-13 19:45:28,616 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
Traceback (most recent call last):
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 248, in __getattr__
return self.data[item]
KeyError: 'size'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
response = self.full_dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
rv = self.handle_user_exception(e)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
raise value
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
rv = self.dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
return_list = main(texts)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
outputs = model(ids)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 1345, in forward
return_dict=return_dict,
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1120, in _call_impl
result = forward_call(*input, **kwargs)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\models\bert\modeling_bert.py", line 944, in forward
input_shape = input_ids.size()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\transformers\tokenization_utils_base.py", line 250, in __getattr__
raise AttributeError
AttributeError
2023-06-13 19:45:28,623 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:45:28] "POST /predict HTTP/1.1" 500 -
2023-06-13 19:45:58,576 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-13 19:45:58,577 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-13 19:50:29,338 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-13 19:50:29,339 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-13 19:50:32,134 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py[line:1892] - ERROR: Exception on /predict [POST]
Traceback (most recent call last):
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 2447, in wsgi_app
response = self.full_dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
rv = self.handle_user_exception(e)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
reraise(exc_type, exc_value, tb)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\_compat.py", line 39, in reraise
raise value
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
rv = self.dispatch_request()
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\flask\app.py", line 1936, in dispatch_request
return self.view_functions[rule.endpoint](**req.view_args)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 98, in handle_query
return_list = main(texts)
File "E:/pycharm_workspace/macbert/flask_macbert.py", line 79, in main
outputs = model.run(None, ids_input)
File "C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\torch\nn\modules\module.py", line 1178, in __getattr__
type(self).__name__, name))
AttributeError: 'BertForMaskedLM' object has no attribute 'run'
2023-06-13 19:50:32,146 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [13/Jun/2023 19:50:32] "POST /predict HTTP/1.1" 500 -
2023-06-14 14:07:51,042 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - WARNING: * Running on all addresses.
WARNING: This is a development server. Do not use it in a production deployment.
2023-06-14 14:07:51,047 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: * Running on http://192.168.31.115:16000/ (Press CTRL+C to quit)
2023-06-14 14:07:54,329 - C:\Users\83887\Anaconda3\envs\bertsum_nezha\lib\site-packages\werkzeug\_internal.py[line:225] - INFO: 192.168.31.115 - - [14/Jun/2023 14:07:54] "POST /predict HTTP/1.1" 200 -
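Note: the tracebacks above all stem from flask_macbert.py invoking the model
incorrectly: a tokenizer BatchEncoding is passed positionally (model(a),
model(ids)), and then .run() is called as though BertForMaskedLM were an ONNX
InferenceSession. A minimal sketch of a call that avoids all three failures
(variable names assumed, not taken from the repo):

    encoded = tokenizer(texts, padding=True, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**encoded)  # unpack input_ids / attention_mask / ...
    pred_ids = outputs.logits.argmax(dim=-1)  # outputs[0] on older transformers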

151
softmaskedbert4csc.py

@ -0,0 +1,151 @@
"""
@Time : 2021-01-21 12:00:59
@File : modeling_soft_masked_bert.py
@Author : Abtion
@Email : abtion{at}outlook.com
"""
from abc import ABC
from collections import OrderedDict
import transformers as tfs
import torch
from torch import nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertEncoder, BertOnlyMLMHead
from transformers.modeling_utils import ModuleUtilsMixin
from pycorrector.macbert.base_model import CscTrainingModel
class DetectionNetwork(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
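        # Bidirectional GRU sized hidden_size // 2 per direction, so the
        # concatenated output width matches hidden_size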
self.gru = nn.GRU(
self.config.hidden_size,
self.config.hidden_size // 2,
num_layers=2,
batch_first=True,
dropout=self.config.hidden_dropout_prob,
bidirectional=True,
)
self.sigmoid = nn.Sigmoid()
self.linear = nn.Linear(self.config.hidden_size, 1)
def forward(self, hidden_states):
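        # hidden_states: (batch, seq_len, hidden_size) -> per-token error
        # probability of shape (batch, seq_len, 1)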
out, _ = self.gru(hidden_states)
prob = self.linear(out)
prob = self.sigmoid(prob)
return prob
class CorrectionNetwork(torch.nn.Module, ModuleUtilsMixin):
def __init__(self, config, tokenizer, device):
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.embeddings = BertEmbeddings(self.config)
self.bert = BertEncoder(self.config)
self.mask_token_id = self.tokenizer.mask_token_id
self.cls = BertOnlyMLMHead(self.config)
self._device = device
def forward(self, texts, prob, embed=None, cor_labels=None, residual_connection=False):
if cor_labels is not None:
text_labels = self.tokenizer(cor_labels, padding=True, return_tensors='pt')['input_ids']
            # torch's cross-entropy loss ignores labels set to -100; map [PAD] (id 0) to -100
text_labels[text_labels == 0] = -100
text_labels = text_labels.to(self._device)
else:
text_labels = None
encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
encoded_texts.to(self._device)
if embed is None:
embed = self.embeddings(input_ids=encoded_texts['input_ids'],
token_type_ids=encoded_texts['token_type_ids'])
        # Slight change from the original paper: computing mask_embed this way
        # keeps the full token_type and position embeddings at masked positions.
mask_embed = self.embeddings(torch.ones_like(prob.squeeze(-1)).long() * self.mask_token_id).detach()
        # Original implementation from the paper:
# mask_embed = self.embeddings(torch.tensor([[self.mask_token_id]], device=self._device)).detach()
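        # Soft masking (Zhang et al., 2020): e'_i = p_i * e_mask + (1 - p_i) * e_i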
cor_embed = prob * mask_embed + (1 - prob) * embed
input_shape = encoded_texts['input_ids'].size()
device = encoded_texts['input_ids'].device
extended_attention_mask = self.get_extended_attention_mask(encoded_texts['attention_mask'],
input_shape, device)
head_mask = self.get_head_mask(None, self.config.num_hidden_layers)
encoder_outputs = self.bert(
cor_embed,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
return_dict=False,
)
sequence_output = encoder_outputs[0]
sequence_output = sequence_output + embed if residual_connection else sequence_output
prediction_scores = self.cls(sequence_output)
out = (prediction_scores, sequence_output)
# Masked language modeling softmax layer
if text_labels is not None:
loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
cor_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), text_labels.view(-1))
out = (cor_loss,) + out
return out
def load_from_transformers_state_dict(self, gen_fp):
state_dict = OrderedDict()
gen_state_dict = tfs.AutoModelForMaskedLM.from_pretrained(gen_fp).state_dict()
for k, v in gen_state_dict.items():
name = k
if name.startswith('bert'):
name = name[5:]
if name.startswith('encoder'):
name = f'corrector.{name[8:]}'
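            # Older checkpoints name LayerNorm parameters gamma/beta; map them
            # to the weight/bias names current transformers expects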
if 'gamma' in name:
name = name.replace('gamma', 'weight')
if 'beta' in name:
name = name.replace('beta', 'bias')
state_dict[name] = v
self.load_state_dict(state_dict, strict=False)
class SoftMaskedBert4Csc(CscTrainingModel, ABC):
def __init__(self, cfg, tokenizer):
super().__init__(cfg)
self.cfg = cfg
self.config = tfs.AutoConfig.from_pretrained(cfg.MODEL.BERT_CKPT)
self.detector = DetectionNetwork(self.config)
self.tokenizer = tokenizer
self.corrector = CorrectionNetwork(self.config, tokenizer, cfg.MODEL.DEVICE)
self.corrector.load_from_transformers_state_dict(self.cfg.MODEL.BERT_CKPT)
self._device = cfg.MODEL.DEVICE
def forward(self, texts, cor_labels=None, det_labels=None):
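        # Returns (det_loss, cor_loss, det_prob, logits, hidden) when both
        # label sets are given, else (det_prob, logits, hidden)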
encoded_texts = self.tokenizer(texts, padding=True, return_tensors='pt')
encoded_texts.to(self._device)
embed = self.corrector.embeddings(input_ids=encoded_texts['input_ids'],
token_type_ids=encoded_texts['token_type_ids'])
prob = self.detector(embed)
cor_out = self.corrector(texts, prob, embed, cor_labels, residual_connection=True)
if det_labels is not None:
det_loss_fct = nn.BCELoss()
            # Exclude padded positions from the detection loss
active_loss = encoded_texts['attention_mask'].view(-1, prob.shape[1]) == 1
active_probs = prob.view(-1, prob.shape[1])[active_loss]
active_labels = det_labels[active_loss]
det_loss = det_loss_fct(active_probs, active_labels.float())
outputs = (det_loss, cor_out[0], prob.squeeze(-1)) + cor_out[1:]
else:
outputs = (prob.squeeze(-1),) + cor_out
return outputs
def load_from_transformers_state_dict(self, gen_fp):
"""
从transformers加载预训练权重
:param gen_fp:
:return:
"""
self.corrector.load_from_transformers_state_dict(gen_fp)

133
train.py

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
@description:
"""
import os
import sys
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import BertTokenizerFast, BertForMaskedLM
import argparse
from collections import OrderedDict
from loguru import logger
sys.path.append('../..')
from pycorrector.macbert.reader import make_loaders, DataCollator
from pycorrector.macbert.macbert4csc import MacBert4Csc
from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
from pycorrector.macbert import preprocess
from pycorrector.macbert.defaults import _C as cfg
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
def args_parse(config_file=''):
parser = argparse.ArgumentParser(description="csc")
parser.add_argument(
"--config_file", default="train_macbert4csc.yml", help="path to config file", type=str
)
parser.add_argument("--opts", help="Modify config options using the command-line key value", default=[],
nargs=argparse.REMAINDER)
args = parser.parse_args()
config_file = args.config_file or config_file
cfg.merge_from_file(config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
logger.info(args)
if config_file != '':
logger.info("Loaded configuration file {}".format(config_file))
with open(config_file, 'r') as cf:
config_str = "\n" + cf.read()
logger.info(config_str)
logger.info("Running with config:\n{}".format(cfg))
return cfg
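# Typical invocations (using the YAML configs shipped alongside; --opts takes
# yacs "KEY VALUE" pairs):
#   python train.py --config_file train_macbert4csc.yml
#   python train.py --config_file train_softmaskedbert4csc.yml --opts SOLVER.BATCH_SIZE 16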
def main():
cfg = args_parse()
    # Preprocess the data first if the training file does not exist
if not os.path.exists(cfg.DATASETS.TRAIN):
logger.debug('preprocess data')
preprocess.main()
logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
collator = DataCollator(tokenizer=tokenizer)
    # Load the data
train_loader, valid_loader, test_loader = make_loaders(collator, train_path=cfg.DATASETS.TRAIN,
valid_path=cfg.DATASETS.VALID, test_path=cfg.DATASETS.TEST,
batch_size=cfg.SOLVER.BATCH_SIZE, num_workers=4)
if cfg.MODEL.NAME == 'softmaskedbert4csc':
model = SoftMaskedBert4Csc(cfg, tokenizer)
elif cfg.MODEL.NAME == 'macbert4csc':
model = MacBert4Csc(cfg, tokenizer)
else:
raise ValueError("model not found.")
    # Resume training from a previously saved checkpoint
    if cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
        # load_from_checkpoint is a classmethod returning a new instance; the
        # result must be assigned, otherwise the loaded weights are discarded
        model = model.load_from_checkpoint(checkpoint_path=cfg.MODEL.WEIGHTS, cfg=cfg,
                                           map_location=device, tokenizer=tokenizer)
    # Configure checkpoint saving
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
ckpt_callback = ModelCheckpoint(
monitor='val_loss',
dirpath=cfg.OUTPUT_DIR,
filename='{epoch:02d}-{val_loss:.2f}',
save_top_k=1,
mode='min'
)
    # Train the model
logger.info('train model ...')
trainer = pl.Trainer(max_epochs=cfg.SOLVER.MAX_EPOCHS,
gpus=None if device == torch.device('cpu') else cfg.MODEL.GPU_IDS,
accumulate_grad_batches=cfg.SOLVER.ACCUMULATE_GRAD_BATCHES,
callbacks=[ckpt_callback])
    # Run training only when train_loader actually has batches
torch.autograd.set_detect_anomaly(True)
if 'train' in cfg.MODE and train_loader and len(train_loader) > 0:
if valid_loader and len(valid_loader) > 0:
trainer.fit(model, train_loader, valid_loader)
else:
trainer.fit(model, train_loader)
logger.info('train model done.')
    # Convert the best checkpoint into a transformers-loadable model
if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
ckpt_path = ckpt_callback.best_model_path
elif cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
ckpt_path = cfg.MODEL.WEIGHTS
else:
ckpt_path = ''
logger.info(f'ckpt_path: {ckpt_path}')
if ckpt_path and os.path.exists(ckpt_path):
model.load_state_dict(torch.load(ckpt_path)['state_dict'])
        # First save the original transformers BERT model and tokenizer
tokenizer.save_pretrained(cfg.OUTPUT_DIR)
bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
bert.save_pretrained(cfg.OUTPUT_DIR)
state_dict = torch.load(ckpt_path)['state_dict']
new_state_dict = OrderedDict()
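        # MacBert4Csc keeps the backbone under the "bert." prefix; strip it so
        # BertForMaskedLM.from_pretrained can load the weights directly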
if cfg.MODEL.NAME in ['macbert4csc']:
for k, v in state_dict.items():
if k.startswith('bert.'):
new_state_dict[k[5:]] = v
else:
new_state_dict = state_dict
        # Then save the fine-tuned weights, replacing the original pytorch_model.bin
torch.save(new_state_dict, os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin'))
    # Testing is gated the same way as training
if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
trainer.test(model, test_loader)
if __name__ == '__main__':
main()

24
train_macbert4csc.yml

@ -0,0 +1,24 @@
MODEL:
BERT_CKPT: "hfl/chinese-macbert-base"
DEVICE: "cuda"
NAME: "macbert4csc"
GPU_IDS: [0]
# [loss_coefficient]
HYPER_PARAMS: [0.3]
#WEIGHTS: "output/macbert4csc/epoch=6-val_loss=0.07.ckpt"
WEIGHTS: ""
DATASETS:
TRAIN: "output/train.json"
VALID: "output/dev.json"
TEST: "output/test.json"
SOLVER:
BASE_LR: 5e-5
WEIGHT_DECAY: 0.01
BATCH_SIZE: 32
MAX_EPOCHS: 10
ACCUMULATE_GRAD_BATCHES: 4
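  # Effective global batch size: 32 * 4 = 128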
OUTPUT_DIR: "output/macbert4csc"
MODE: ["train", "test"]

24
train_softmaskedbert4csc.yml

@ -0,0 +1,24 @@
MODEL:
BERT_CKPT: "bert-base-chinese"
DEVICE: "cuda"
NAME: "softmaskedbert4csc"
GPU_IDS: [0]
# [loss_coefficient]
HYPER_PARAMS: [0.8]
#WEIGHTS: "output/softmaskedbert4csc/epoch=2-val_loss=0.07.ckpt"
WEIGHTS: ""
DATASETS:
TRAIN: "output/train.json"
VALID: "output/dev.json"
TEST: "output/test.json"
SOLVER:
BASE_LR: 0.0001
WEIGHT_DECAY: 5e-8
BATCH_SIZE: 32
MAX_EPOCHS: 10
ACCUMULATE_GRAD_BATCHES: 4
OUTPUT_DIR: "output/softmaskedbert4csc"
MODE: ["train", "test"]