Error correction task

# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com), okcd00(okcd00@qq.com)
@description: Inference script for the MacBert4Csc and SoftMaskedBert4Csc Chinese spelling correction models.
"""
import sys
import torch
import argparse
from transformers import BertTokenizerFast
from loguru import logger
sys.path.append('../..')
from pycorrector.macbert.macbert4csc import MacBert4Csc
from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
from pycorrector.macbert.macbert_corrector import get_errors
from pycorrector.macbert.defaults import _C as cfg
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Inference:
    def __init__(self, ckpt_path='output/macbert4csc/epoch=09-val_loss=0.01.ckpt',
                 vocab_path='output/macbert4csc/vocab.txt',
                 cfg_path='train_macbert4csc.yml'):
        logger.debug("device: {}".format(device))
        self.tokenizer = BertTokenizerFast.from_pretrained(vocab_path)
        cfg.merge_from_file(cfg_path)
        # Select the model class from the config file name
        if 'macbert4csc' in cfg_path:
            self.model = MacBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
                                                          cfg=cfg,
                                                          map_location=device,
                                                          tokenizer=self.tokenizer)
        elif 'softmaskedbert4csc' in cfg_path:
            self.model = SoftMaskedBert4Csc.load_from_checkpoint(checkpoint_path=ckpt_path,
                                                                 cfg=cfg,
                                                                 map_location=device,
                                                                 tokenizer=self.tokenizer)
        else:
            raise ValueError("model not found.")
        self.model.to(device)
        self.model.eval()

    def predict(self, sentence_list):
        """
        Run text correction on the input text(s).
        Args:
            sentence_list: list or str
                input text(s) to correct
        Returns:
            corrected_texts: list, or a single str if a str was passed in
        """
        is_str = False
        if isinstance(sentence_list, str):
            is_str = True
            sentence_list = [sentence_list]
        corrected_texts = self.model.predict(sentence_list)
        if is_str:
            return corrected_texts[0]
        return corrected_texts

    def predict_with_error_detail(self, sentence_list):
        """
        Run text correction on the input text(s), returning error position details as well.
        Args:
            sentence_list: list or str
                input text(s) to correct
        Returns: tuple
            corrected_texts(list), details(list)
        """
        details = []
        is_str = False
        if isinstance(sentence_list, str):
            is_str = True
            sentence_list = [sentence_list]
        corrected_texts = self.model.predict(sentence_list)
        for corrected_text, text in zip(corrected_texts, sentence_list):
            corrected_text, sub_details = get_errors(corrected_text, text)
            details.append(sub_details)
        if is_str:
            return corrected_texts[0], details[0]
        return corrected_texts, details
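
# Usage sketch (an illustration, assuming a trained checkpoint exists at the default
# ckpt_path; each entry in `details` is assumed here to follow the get_errors
# convention of (wrong_char, right_char, begin_idx, end_idx)):
#   m = Inference()
#   corrected, details = m.predict_with_error_detail('少先队 员因该 为老人让 坐')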


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="infer")
    parser.add_argument("--ckpt_path", default="output/macbert4csc/epoch=09-val_loss=0.01.ckpt",
                        help="path to model checkpoint file", type=str)
    parser.add_argument("--vocab_path", default="output/macbert4csc/vocab.txt", help="path to vocab file", type=str)
    parser.add_argument("--config_file", default="train_macbert4csc.yml", help="path to config file", type=str)
    args = parser.parse_args()
    m = Inference(args.ckpt_path, args.vocab_path, args.config_file)
    inputs = [
        '它的本领是呼风唤雨,因此能灭火防灾。狎鱼后面是獬豸。獬豸通常头上长着独角,有时又被称为独角羊。它很聪彗,而且明辨是非,象征着大公无私,又能镇压斜恶。',
        '老是较书。',
        '少先队 员因该 为老人让 坐',
        '感谢等五分以后,碰到一位很棒的奴生跟我可聊。',
        '遇到一位很棒的奴生跟我聊天。',
        '遇到一位很美的女生跟我疗天。',
        '他们只能有两个选择:接受降新或自动离职。',
        '王天华开心得一直说话。',
        '你说:"怎么办?"我怎么知道?',
    ]
    outputs = m.predict(inputs)
    for a, b in zip(inputs, outputs):
        print('input :', a)
        print('predict:', b)
        print()

    # Evaluate the model on the sighan2015 test set
    # macbert4csc        Sentence Level: acc:0.7845, precision:0.8174, recall:0.7256, f1:0.7688, cost time:10.79 s
    # softmaskedbert4csc Sentence Level: acc:0.6964, precision:0.8065, recall:0.5064, f1:0.6222, cost time:16.20 s
    from pycorrector.utils.eval import eval_sighan2015_by_model
    eval_sighan2015_by_model(m.predict_with_error_detail)