import os
import operator
import uuid
import json
import time
import re
import logging
import unicodedata
from threading import Thread

import torch
from flask import Flask, jsonify, request
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(level=logging.DEBUG,  # minimum level of records to write
                    filename='rewrite.log',
                    # filemode: 'w' rewrites the log file on every start,
                    # overwriting earlier logs; 'a' appends and is the default
                    filemode='a',
                    # log record format
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')

db_key_query = 'query'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # emit non-ASCII JSON verbatim instead of \u escapes

pattern = r"[。]"  # full-stop pattern (defined but not used below)
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")  # quoted dialog spans
fuhao_end_sentence = ["。", ",", "?", "!", "…"]  # sentence-ending punctuation

tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
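# Assumption (not in the original): the service only runs inference, so eval
# mode is enabled here to disable dropout and get deterministic predictions.
model.eval()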


def is_chinese(char):
    # unicodedata.name() raises ValueError for characters without a name
    # (e.g. '\n'); passing a default treats them as non-Chinese instead
    return 'CJK' in unicodedata.name(char, '')
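# For example: unicodedata.name('中') is 'CJK UNIFIED IDEOGRAPH-4E2D', so
# is_chinese('中') is True, while is_chinese('a') is False.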


class SentenceUlit:
    """Splits a sentence into non-Chinese characters (kept in place) and runs
    of consecutive Chinese characters (collected for batch correction)."""

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)  # "" marks a slot to refill later
        self.last_post = False  # was the previous character Chinese?
        self.sentence_batch = []  # runs of consecutive Chinese characters
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        # duplicates the module-level helper; kept for compatibility
        return 'CJK' in unicodedata.name(char, '')

    def pre_ulit(self):
        for i, d in enumerate(self.sentence):
            if not is_chinese(d):
                self.sentence_list[i] = d  # non-Chinese chars keep their position
                self.last_post = False
            else:
                if not self.last_post:
                    self.sentence_batch.append(d)  # start a new run
                else:
                    self.sentence_batch[-1] += d  # extend the current run
                self.last_post = True

    def inf_ulit(self, sen):
        # flatten the (corrected) runs back into single characters
        for i in sen:
            self.inf_sentence_batch_str += i
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)

        # refill the empty slots left by pre_ulit, in order
        for i, d in enumerate(self.sentence_list):
            if d == "":
                zi = self.inf_sentence_batch_srt_list.pop(0)
                self.sentence_list[i] = zi
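# Illustrative example (not in the original): for SentenceUlit("abc中文de"),
# pre_ulit leaves sentence_list as ['a', 'b', 'c', '', '', 'd', 'e'] and
# sentence_batch as ['中文']; inf_ulit(['中文']) then refills the empty slots.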


class log:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        fmt = '%Y/%m/%d-%H:%M:%S'
        fmt_day = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(fmt, value)
        dt_log_file = time.strftime(fmt_day, value)
        log_file = 'log_file/access-%s.log' % dt_log_file
        # assumption: the log_file/ directory may not exist yet
        os.makedirs('log_file', exist_ok=True)
        # mode 'a' creates the file when missing, so no exists() check is needed
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
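# Usage sketch: log().log("message") appends a timestamped line to
# log_file/access-YYYY-MM-DD.log.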


def get_errors(corrected_text, origin_text):
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # characters the tokenizer drops or maps to [UNK]: re-insert the original
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # keep the original casing of English letters
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details
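# Illustrative example (not from the original): get_errors("应该", "因该") returns
# ("应该", [("因", "应", 0, 1)]): the original char, its correction, and the span.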


def main(texts):
    with torch.no_grad():
        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))

    result = []
    logging.debug(outputs.logits)  # debug output, routed to the log instead of stdout
    for ids, text in zip(outputs.logits, texts):
        # greedy decode: take the argmax token at every position
        _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
        corrected_text = _text[:len(text)]
        logging.debug(corrected_text)
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text,
                       "new": corrected_text,
                       "re_pos": details})
    return result
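# Illustrative shape of the return value (assuming the model corrects the input):
# [{"old": original, "new": corrected, "re_pos": [(char, fix, i, i + 1), ...]}]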


@app.route("/predict", methods=["POST"])
def handle_query():
    logging.debug(request.remote_addr)
    texts = request.json["texts"]
    return_list = main(texts)
    # assumption: "resilt" in the original was a typo for "result"
    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result as JSON
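# The endpoint expects a JSON body of the form {"texts": ["sentence", ...]} and
# returns each input together with its correction and substitution positions.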


if __name__ == "__main__":
    # logging is already configured at module level (a repeated
    # logging.basicConfig call would be a no-op)
    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
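
# A minimal client sketch (illustrative; assumes the `requests` package and a
# server running locally):
#
#   import requests
#   resp = requests.post("http://127.0.0.1:16000/predict",
#                        json={"texts": ["少先队员因该为老人让座"]})
#   print(resp.json())  # {"result": [...], "probabilities": None, "status_code": 200}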