"""Run a MacBERT masked-LM spelling corrector over a batch of Chinese sentences.

Each sentence is fed through ``BertForMaskedLM``. The per-position argmax
prediction is decoded back to text and compared character by character with
the input to report substitutions.
"""
import operator

import torch
from transformers import BertTokenizerFast, BertForMaskedLM

# Characters that get_errors re-inserts into the decoded prediction instead of
# comparing them. Spaces are removed from the decoded text below. The others
# are presumably mapped to [UNK] by the tokenizer and then dropped by
# skip_special_tokens=True (verify against the model vocab).
UNK_CHARS = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']


def get_errors(corrected_text, origin_text):
    """Align a decoded prediction with the original text and list substitutions.

    Args:
        corrected_text: Decoded model prediction (spaces removed, truncated to
            ``len(origin_text)``).
        origin_text: The original input sentence.

    Returns:
        A tuple ``(corrected_text, sub_details)``:

        * ``corrected_text`` is the prediction with every ``UNK_CHARS``
          character from ``origin_text`` re-inserted at its original index.
          Case-only differences are reverted to the original character.
        * ``sub_details`` is a list of
          ``(original_char, corrected_char, start, end)`` tuples, sorted by
          ``start``.
    """
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in UNK_CHARS:
            # The character is missing from the prediction; put it back so the
            # following indices stay aligned with origin_text.
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            # The prediction is shorter than the input; nothing to compare.
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # Only the case differs (the prediction is presumably
                # lowercased by the tokenizer). Keep the original character
                # rather than reporting a correction.
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details.sort(key=operator.itemgetter(2))
    return corrected_text, sub_details


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
model.eval()

texts = ["今天新情很好,你找到你最喜欢的工作,我也很高心。",
         "今天新情很好,你找到你最喜欢的工作,我也很高心。"]

with torch.no_grad():
    # BatchEncoding.to() moves every tensor in the batch, so the individual
    # tensors below are already on `device`.
    inputs = tokenizer(texts, padding=True, return_tensors='pt').to(device)
    print(inputs)
    input_ids = inputs['input_ids']
    token_type_ids = inputs["token_type_ids"]
    attention_mask = inputs['attention_mask']
    print()
    # Pass by keyword: BertForMaskedLM.forward's positional order is
    # (input_ids, attention_mask, token_type_ids, ...). The previous positional
    # call swapped the attention mask and the token type ids.
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

result = []
for ids, text in zip(outputs.logits, texts):
    # Take the most likely token at every position, then decode it.
    # Special tokens ([CLS], [SEP], [PAD], ...) are dropped, and spaces
    # inserted between tokens by decode() are removed.
    _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
    corrected_text = _text[:len(text)]
    corrected_text, details = get_errors(corrected_text, text)
    print(text, ' => ', corrected_text, details)
    result.append((text, corrected_text, details))
print(result)