Browse Source

更改英文处理

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
9a653a5f17
  1. 50
      flask_macbert.py

50
flask_macbert.py

@ -4,7 +4,8 @@ from flask import request
import operator
import torch
from transformers import BertTokenizerFast, BertForMaskedLM
# Select CUDA when torch reports it as available, otherwise fall back to CPU.
# NOTE(review): this listing looks like a rendered diff without +/- markers;
# this line is presumably the pre-change side -- confirm against the real file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hard-codes CPU; when both assignments are present this one wins because it
# runs last.
device = "cpu"
import uuid
import json
from threading import Thread
@ -121,20 +122,53 @@ def get_errors(corrected_text, origin_text):
# main(texts): run the masked-LM `model` over the given texts and return one
# dict per input text with the keys "old" (the input text), "new" and "re_pos"
# (the two values returned by get_errors).
#
# NOTE(review): this listing appears to be a rendered diff with indentation and
# +/- markers stripped, so pre-change and post-change statements are
# interleaved. Lines flagged "presumably pre-change" below should be confirmed
# against the actual file before relying on them.
def main(texts):
# Wrap every input text in a SentenceUlit (defined elsewhere in the file).
batch_size_sentent_class = []
for text in texts:
batch_size_sentent_class.append(SentenceUlit(text))
# Flatten each wrapper's `sentence_batch` into one list for a single model
# call, and remember how many segments each text contributed so the flat
# results can be regrouped per text further down.
batch_pre = []
batch_nums = []
for sentent_class in batch_size_sentent_class:
sentents = sentent_class.sentence_batch
batch_pre.extend(sentents)
batch_nums.append(len(sentents))
# Inference only: gradient tracking is disabled for the forward pass.
with torch.no_grad():
# NOTE(review): presumably the pre-change line -- it feeds the raw `texts`
# to the model and its result is overwritten by the assignment below.
outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
# input_pre = tokenizer(batch_pre, padding=True, return_tensors='pt').to(device)
# input_ids = input_pre['input_ids'].to(device)
# token_type_ids = input_pre["token_type_ids"].to(device)
# attention_mask = input_pre['attention_mask'].to(device)
# outputs = model(input_ids, token_type_ids, attention_mask)
# Tokenize the flattened segments with padding, move the tensors to `device`
# and run the model on them.
outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device))
# NOTE(review): the next three lines are presumably pre-change residue; the
# `for` header over `texts` has no body of its own in this listing.
result = []
print(outputs.logits)
for ids, text in zip(outputs.logits, texts):
# Decode one corrected string per input segment.
batch_res = []
for ids,data_dan in zip(outputs.logits,batch_pre):
# Take the argmax token id along the last dimension, decode it without
# special tokens and strip the spaces from the decoded string.
_text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
# NOTE(review): presumably pre-change -- truncates to the length of `text`
# and is immediately superseded by the `data_dan` variant below.
corrected_text = _text[:len(text)]
print(corrected_text)
# Cut the decoded string to the character length of the segment fed in.
corrected_text = _text[:len(data_dan)]
batch_res.append(corrected_text)
# Debug output of the model inputs and the decoded outputs.
print(batch_pre)
print(batch_res)
# Regroup the flat `batch_res` into one sub-list per input text by slicing
# consecutive runs of the sizes recorded in `batch_nums`.
batch_new = []
index = 0
for i in batch_nums:
index_new = index + i
batch_new.append(batch_res[index:index_new])
index = index_new
# Hand each text's corrected segments to its wrapper, then join the wrapper's
# `sentence_list` into a single string.
# NOTE(review): inf_ulit is defined elsewhere; presumably it writes the
# corrected segments into `sentence_list` -- confirm.
batch_pre_data = []
for dan, sentent_class in zip(batch_new, batch_size_sentent_class):
sentent_class.inf_ulit(dan)
batch_pre_data.append("".join(sentent_class.sentence_list))
# Pair every original text with its reassembled correction; get_errors is
# called as (corrected, original) and returns the final text plus details.
result = []
for text, corrected_text in zip(texts,batch_pre_data):
corrected_text, details = get_errors(corrected_text, text)
result.append({"old": text,
"new": corrected_text,
"re_pos": details})
return result
@ -156,4 +190,4 @@ if __name__ == "__main__":
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
# 日志格式
)
app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)

Loading…
Cancel
Save