diff --git a/flask_macbert.py b/flask_macbert.py
index 818afb4..a3f0bf9 100644
--- a/flask_macbert.py
+++ b/flask_macbert.py
@@ -4,7 +4,8 @@ from flask import request
 import operator
 import torch
 from transformers import BertTokenizerFast, BertForMaskedLM
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = "cpu"
 import uuid
 import json
 from threading import Thread
@@ -121,20 +122,53 @@ def get_errors(corrected_text, origin_text):
 def main(texts):
+    batch_size_sentent_class = []
+
+    for text in texts:
+        batch_size_sentent_class.append(SentenceUlit(text))
+
+    batch_pre = []
+    batch_nums = []
+    for sentent_class in batch_size_sentent_class:
+        sentents = sentent_class.sentence_batch
+        batch_pre.extend(sentents)
+        batch_nums.append(len(sentents))
+
     with torch.no_grad():
-        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
+        # input_pre = tokenizer(batch_pre, padding=True, return_tensors='pt').to(device)
+        # input_ids = input_pre['input_ids'].to(device)
+        # token_type_ids = input_pre["token_type_ids"].to(device)
+        # attention_mask = input_pre['attention_mask'].to(device)
+        # outputs = model(input_ids, token_type_ids, attention_mask)
+        outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device))
 
-    result = []
-    print(outputs.logits)
-    for ids, text in zip(outputs.logits, texts):
+    batch_res = []
+    for ids, data_dan in zip(outputs.logits, batch_pre):
         _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
-        corrected_text = _text[:len(text)]
-        print(corrected_text)
+        corrected_text = _text[:len(data_dan)]
+        batch_res.append(corrected_text)
+    print(batch_pre)
+    print(batch_res)
+    batch_new = []
+    index = 0
+    for i in batch_nums:
+        index_new = index + i
+        batch_new.append(batch_res[index:index_new])
+        index = index_new
+
+    batch_pre_data = []
+    for dan, sentent_class in zip(batch_new, batch_size_sentent_class):
+        sentent_class.inf_ulit(dan)
+        batch_pre_data.append("".join(sentent_class.sentence_list))
+
+    result = []
+    for text, corrected_text in zip(texts, batch_pre_data):
         corrected_text, details = get_errors(corrected_text, text)
         result.append({"old": text, "new": corrected_text, "re_pos": details})
+
     return result
@@ -156,4 +190,4 @@ if __name__ == "__main__":
         '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'  # log format
     )
-    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
+    app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)
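
Note: the reworked main() leans on a SentenceUlit helper (its sentence_batch, inf_ulit() and sentence_list members) that is defined elsewhere in flask_macbert.py and does not appear in this diff. The sketch below is a hypothetical reconstruction of the contract the added code assumes, inferred only from how those members are used above; the actual splitting logic may differ.

    import re

    class SentenceUlit:
        # Hypothetical sketch: split one long text into sentence-sized
        # chunks that fit the model's input limit, then merge corrected
        # chunks back. Only the member names come from the diff.
        def __init__(self, text):
            # Assumed naive split on sentence-ending punctuation.
            parts = re.split(r'(?<=[。！？!?])', text)
            self.sentence_list = [p for p in parts if p]
            # Chunks actually sent to the model.
            self.sentence_batch = list(self.sentence_list)

        def inf_ulit(self, corrected_chunks):
            # Overwrite each chunk with its corrected version, in order,
            # so "".join(self.sentence_list) is the corrected full text.
            for i, chunk in enumerate(corrected_chunks):
                self.sentence_list[i] = chunk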
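
The batch_nums / batch_new loop in the diff is a flatten-then-split pattern: each input text contributes a variable number of chunks to one flat model batch, batch_nums records how many, and the flat result list is sliced back per text afterwards. A standalone illustration with made-up values:

    # Three texts contributed 2, 3 and 1 chunks respectively.
    batch_nums = [2, 3, 1]
    batch_res = ["a1", "a2", "b1", "b2", "b3", "c1"]

    batch_new, index = [], 0
    for n in batch_nums:
        batch_new.append(batch_res[index:index + n])
        index += n

    assert batch_new == [["a1", "a2"], ["b1", "b2", "b3"], ["c1"]]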