纠错任务
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

46 lines
2.0 KiB

import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
texts = ["今天新情很好,你找到你最喜欢的工作,我也很高心。", "今天新情很好,你找到你最喜欢的工作,我也很高心。"]
with torch.no_grad():
input = tokenizer(texts, padding=True, return_tensors='pt').to(device)
print(input)
input_ids = input['input_ids'].to(device)
token_type_ids = input["token_type_ids"].to(device)
attention_mask = input['attention_mask'].to(device)
print()
outputs = model(input_ids,token_type_ids,attention_mask)
def get_errors(corrected_text, origin_text):
sub_details = []
for i, ori_char in enumerate(origin_text):
if ori_char in [' ', '', '', '', '', '', '\n', '', '', '']:
# add unk word
corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
continue
if i >= len(corrected_text):
continue
if ori_char != corrected_text[i]:
if ori_char.lower() == corrected_text[i]:
# pass english upper char
corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
continue
sub_details.append((ori_char, corrected_text[i], i, i + 1))
sub_details = sorted(sub_details, key=operator.itemgetter(2))
return corrected_text, sub_details
result = []
for ids, text in zip(outputs.logits, texts):
_text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
corrected_text = _text[:len(text)]
corrected_text, details = get_errors(corrected_text, text)
print(text, ' => ', corrected_text, details)
result.append((text, corrected_text, details))
print(result)