@@ -4,7 +4,8 @@ from flask import request
 import operator
 import torch
 from transformers import BertTokenizerFast, BertForMaskedLM
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = "cpu"
 import uuid
 import json
 from threading import Thread
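This first hunk pins inference to the CPU even on machines where CUDA is available. Note that a plain string works anywhere a torch.device does, so the literal "cpu" is a drop-in for the commented-out torch.device(...) form — a quick check:

    import torch

    # .to() accepts either a string or a torch.device; both land on the same
    # device, which is why the patch can replace torch.device(...) with "cpu".
    t = torch.zeros(2)
    assert t.to("cpu").device == t.to(torch.device("cpu")).device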
@@ -121,20 +122,53 @@ def get_errors(corrected_text, origin_text):
 
 
 def main(texts):
+    batch_size_sentent_class = []
+    for text in texts:
+        batch_size_sentent_class.append(SentenceUlit(text))
+
+    batch_pre = []
+    batch_nums = []
+    for sentent_class in batch_size_sentent_class:
+        sentents = sentent_class.sentence_batch
+        batch_pre.extend(sentents)
+        batch_nums.append(len(sentents))
+
     with torch.no_grad():
-        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
-    result = []
-    for ids, text in zip(outputs.logits, texts):
+        # input_pre = tokenizer(batch_pre, padding=True, return_tensors='pt').to(device)
+        # input_ids = input_pre['input_ids'].to(device)
+        # token_type_ids = input_pre["token_type_ids"].to(device)
+        # attention_mask = input_pre['attention_mask'].to(device)
+        # outputs = model(input_ids, token_type_ids, attention_mask)
+        outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device))
+
+    batch_res = []
+    print(outputs.logits)
+    for ids, data_dan in zip(outputs.logits, batch_pre):
         _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
-        corrected_text = _text[:len(text)]
+        corrected_text = _text[:len(data_dan)]
+        print(corrected_text)
+        batch_res.append(corrected_text)
+    print(batch_pre)
+    print(batch_res)
+
+    batch_new = []
+    index = 0
+    for i in batch_nums:
+        index_new = index + i
+        batch_new.append(batch_res[index:index_new])
+        index = index_new
+
+    batch_pre_data = []
+    for dan, sentent_class in zip(batch_new, batch_size_sentent_class):
+        sentent_class.inf_ulit(dan)
+        batch_pre_data.append("".join(sentent_class.sentence_list))
+
+    result = []
+    for text, corrected_text in zip(texts, batch_pre_data):
         corrected_text, details = get_errors(corrected_text, text)
         result.append({"old": text,
                        "new": corrected_text,
                        "re_pos": details})
     return result
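The reworked main() follows a flatten → infer → regroup pattern: each input text is split into sentence chunks (via the module's SentenceUlit helper), all chunks go through the model as one padded batch, and batch_nums records how many chunks each text contributed so the flat predictions can be sliced back out per text. A minimal sketch of just that bookkeeping, independent of the model:

    from typing import List

    def regroup(flat_results: List[str], batch_nums: List[int]) -> List[List[str]]:
        # Re-slice the flat list of corrected chunks into one sub-list per
        # original text, mirroring the batch_new loop in the patch above.
        grouped, index = [], 0
        for n in batch_nums:
            grouped.append(flat_results[index:index + n])
            index += n
        return grouped

    # Two texts that contributed 2 and 1 chunks respectively:
    print(regroup(["a", "b", "c"], [2, 1]))  # [['a', 'b'], ['c']]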
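The correction step itself (tokenize, forward pass, per-position argmax, decode, then truncate to the input length, since predictions at the [CLS]/[SEP] slots survive decoding) can be exercised standalone. A sketch under assumptions: the checkpoint name below is a placeholder, as the diff never shows which model the service loads, and a vanilla masked LM will mostly echo its input, so treat this as an API/shape check rather than a working corrector:

    import torch
    from transformers import BertTokenizerFast, BertForMaskedLM

    name = "bert-base-chinese"  # placeholder; the real checkpoint is not in this diff
    tokenizer = BertTokenizerFast.from_pretrained(name)
    model = BertForMaskedLM.from_pretrained(name).eval()

    chunks = ["今天天气很好"]  # stands in for batch_pre
    with torch.no_grad():
        outputs = model(**tokenizer(chunks, padding=True, return_tensors='pt'))

    for logits, chunk in zip(outputs.logits, chunks):
        # Most-likely token at every position, decoded and stripped of the
        # spaces the tokenizer inserts between CJK characters.
        decoded = tokenizer.decode(torch.argmax(logits, dim=-1),
                                   skip_special_tokens=True).replace(' ', '')
        # Heuristic truncation, as in the patch: assumes one character per token.
        print(decoded[:len(chunk)])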
@@ -156,4 +190,4 @@ if __name__ == "__main__":
         '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
         # log format
     )
-    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
+    app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)
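The last change only moves the listen port from 16000 to 16001. If the port is expected to keep changing, reading it from the environment avoids further code edits — a minimal sketch, where the PORT variable name and its default are assumptions, not part of this patch:

    import os
    from flask import Flask

    app = Flask(__name__)

    if __name__ == "__main__":
        # Hypothetical override: default to the patched port, let $PORT win if set.
        port = int(os.environ.get("PORT", "16001"))
        app.run(host="0.0.0.0", port=port, threaded=True, debug=False)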