Browse Source

更改英文处理

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
9a653a5f17
  1. 50
      flask_macbert.py

50
flask_macbert.py

@ -4,7 +4,8 @@ from flask import request
import operator import operator
import torch import torch
from transformers import BertTokenizerFast, BertForMaskedLM from transformers import BertTokenizerFast, BertForMaskedLM
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
import uuid import uuid
import json import json
from threading import Thread from threading import Thread
@ -121,20 +122,53 @@ def get_errors(corrected_text, origin_text):
def main(texts): def main(texts):
batch_size_sentent_class = []
for text in texts:
batch_size_sentent_class.append(SentenceUlit(text))
batch_pre = []
batch_nums = []
for sentent_class in batch_size_sentent_class:
sentents = sentent_class.sentence_batch
batch_pre.extend(sentents)
batch_nums.append(len(sentents))
with torch.no_grad(): with torch.no_grad():
outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device)) # input_pre = tokenizer(batch_pre, padding=True, return_tensors='pt').to(device)
# input_ids = input_pre['input_ids'].to(device)
# token_type_ids = input_pre["token_type_ids"].to(device)
# attention_mask = input_pre['attention_mask'].to(device)
# outputs = model(input_ids, token_type_ids, attention_mask)
outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device))
result = [] batch_res = []
print(outputs.logits)
for ids, text in zip(outputs.logits, texts):
for ids,data_dan in zip(outputs.logits,batch_pre):
_text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '') _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
corrected_text = _text[:len(text)] corrected_text = _text[:len(data_dan)]
print(corrected_text) batch_res.append(corrected_text)
print(batch_pre)
print(batch_res)
batch_new = []
index = 0
for i in batch_nums:
index_new = index + i
batch_new.append(batch_res[index:index_new])
index = index_new
batch_pre_data = []
for dan, sentent_class in zip(batch_new, batch_size_sentent_class):
sentent_class.inf_ulit(dan)
batch_pre_data.append("".join(sentent_class.sentence_list))
result = []
for text, corrected_text in zip(texts,batch_pre_data):
corrected_text, details = get_errors(corrected_text, text) corrected_text, details = get_errors(corrected_text, text)
result.append({"old": text, result.append({"old": text,
"new": corrected_text, "new": corrected_text,
"re_pos": details}) "re_pos": details})
return result return result
@ -156,4 +190,4 @@ if __name__ == "__main__":
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
# 日志格式 # 日志格式
) )
app.run(host="0.0.0.0", port=16000, threaded=True, debug=False) app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)

Loading…
Cancel
Save