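"""Flask service for Chinese spelling correction with a MacBERT masked-LM checkpoint.

Each request text is split into runs of consecutive Chinese characters; the runs
are corrected in one model batch and then merged back into their original positions.
"""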
import os
import operator
import uuid
import json
import time
import re
import logging
import unicodedata
from threading import Thread

import torch
from flask import Flask, jsonify, request
from transformers import BertTokenizerFast, BertForMaskedLM

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

logging.basicConfig(level=logging.DEBUG,  # log level for the root logger
                    filename='rewrite.log',
                    filemode='a',  # 'w' overwrites the log file on every run; 'a' (the default) appends
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # log format

db_key_query = 'query'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # allow non-ASCII (Chinese) characters in JSON responses

# Sentence-splitting helpers (defined here but not used in the current flow).
pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
model.eval()  # inference only: disable dropout so corrections are deterministic

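# A minimal smoke test of the corrector (a sketch; the input sentence is a made-up
# example with the typo 平果, which the model is expected to rewrite as 苹果):
#   inputs = tokenizer(["我喜欢吃平果"], padding=True, return_tensors='pt').to(device)
#   with torch.no_grad():
#       logits = model(**inputs).logits
#   print(tokenizer.decode(torch.argmax(logits[0], dim=-1), skip_special_tokens=True))
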

def is_chinese(char):
    """Return True if `char` is a CJK character."""
    try:
        return 'CJK' in unicodedata.name(char)
    except ValueError:
        # unicodedata.name() raises ValueError for unnamed code points (e.g. '\n')
        return False


class SentenceUlit:
    """Splits a sentence into runs of consecutive Chinese characters for batched correction."""

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)  # non-Chinese chars go here; "" marks a slot to fill
        self.last_post = False  # True while the previous char was Chinese
        self.sentence_batch = []  # runs of consecutive Chinese characters
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        try:
            return 'CJK' in unicodedata.name(char)
        except ValueError:
            return False

    def pre_ulit(self):
        """Copy non-Chinese chars straight into sentence_list and group
        consecutive Chinese chars into sentence_batch."""
        for i, d in enumerate(self.sentence):
            if not is_chinese(d):
                self.sentence_list[i] = d
                self.last_post = False
            else:
                if not self.last_post:
                    self.sentence_batch.append(d)  # start a new run
                else:
                    self.sentence_batch[-1] += d  # extend the current run
                self.last_post = True

    def inf_ulit(self, sen):
        """Write the corrected runs back into the empty slots of sentence_list."""
        for i in sen:
            self.inf_sentence_batch_str += i
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)

        for i, d in enumerate(self.sentence_list):
            if d == "":
                zi = self.inf_sentence_batch_srt_list.pop(0)
                self.sentence_list[i] = zi

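# Sketch of how SentenceUlit splits a mixed string (illustrative example input):
#   s = SentenceUlit("abc中文def测试")
#   s.sentence_batch  -> ["中文", "测试"]                                # runs sent to the model
#   s.sentence_list   -> ["a", "b", "c", "", "", "d", "e", "f", "", ""]  # "" = slot to fill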

class log:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        if not os.path.exists(log_file):
            with open(log_file, 'w', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)
        else:
            with open(log_file, 'a+', encoding='utf-8') as f:
                print(dt, *args, file=f, **kwargs)


def get_errors(corrected_text, origin_text):
    """Align corrected_text with origin_text and collect substitutions as
    (original_char, corrected_char, start, end) tuples."""
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # re-insert characters the tokenizer maps to [UNK] or drops
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # keep the original casing of English letters
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details

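# Example of the alignment (hypothetical values; 平 -> 苹 at position 4):
#   get_errors("我喜欢吃苹果", "我喜欢吃平果")
#   -> ("我喜欢吃苹果", [("平", "苹", 4, 5)])
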

def main(texts):
    # Wrap each input text; this extracts the runs of Chinese characters to correct.
    batch_size_sentent_class = [SentenceUlit(text) for text in texts]

    # Flatten all runs into one batch, remembering how many runs each text contributed.
    batch_pre = []
    batch_nums = []
    for sentent_class in batch_size_sentent_class:
        sentents = sentent_class.sentence_batch
        batch_pre.extend(sentents)
        batch_nums.append(len(sentents))

    with torch.no_grad():
        # tokenize the runs and run the masked-LM over them in a single batch
        outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device))

    # Decode the argmax tokens and trim each result to the original run length.
    batch_res = []
    for ids, data_dan in zip(outputs.logits, batch_pre):
        _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
        corrected_text = _text[:len(data_dan)]
        batch_res.append(corrected_text)
    print(batch_pre)  # debug output
    print(batch_res)

    # Regroup the flat results per input text.
    batch_new = []
    index = 0
    for i in batch_nums:
        index_new = index + i
        batch_new.append(batch_res[index:index_new])
        index = index_new

    # Merge the corrected runs back into each original sentence.
    batch_pre_data = []
    for dan, sentent_class in zip(batch_new, batch_size_sentent_class):
        sentent_class.inf_ulit(dan)
        batch_pre_data.append("".join(sentent_class.sentence_list))

    # Align each corrected sentence with its original and collect the substitutions.
    result = []
    for text, corrected_text in zip(texts, batch_pre_data):
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text,
                       "new": corrected_text,
                       "re_pos": details})

    return result

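# Shape of main()'s return value (illustrative; actual corrections depend on the model):
#   main(["我喜欢吃平果"])
#   -> [{"old": "我喜欢吃平果", "new": "我喜欢吃苹果", "re_pos": [("平", "苹", 4, 5)]}]
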

@app.route("/predict", methods=["POST"])
def handle_query():
    print(request.remote_addr)
    texts = request.json["texts"]
    return_list = main(texts)
    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result

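# Example client call (a sketch; assumes the service is running locally on port 16001):
#   import requests
#   resp = requests.post("http://127.0.0.1:16001/predict",
#                        json={"texts": ["我喜欢吃平果"]})
#   print(resp.json()["result"])
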

if __name__ == "__main__":
    # basicConfig is a no-op here: the root logger was already configured at import time above.
    logging.basicConfig(level=logging.DEBUG,  # log level for the root logger
                        filename='rewrite.log',
                        filemode='a',  # 'w' overwrites the log file on every run; 'a' (the default) appends
                        format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
    app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)