# 纠错任务 (error-correction task): a Chinese spelling-correction HTTP service
# built on the macbert4csc-base-chinese model.
import logging
import operator
import os
import re
import time
import unicodedata

import torch
from flask import Flask, jsonify, request
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(level=logging.DEBUG,  # minimum level to record
                    filename='rewrite.log',
                    filemode='a',  # 'a' appends; 'w' would overwrite the log on every start
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')  # log format
db_key_query = 'query'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # return raw UTF-8 instead of \u escapes

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
# Sentence-ending punctuation. The original characters were garbled in the
# source, so these five values are an assumption.
fuhao_end_sentence = ["。", "!", "?", ";", "…"]

tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
def is_chinese(char):
    """Return True if `char` is a CJK character."""
    try:
        return 'CJK' in unicodedata.name(char)
    except ValueError:
        # control characters such as '\n' have no Unicode name
        return False
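

# e.g. (sketch): is_chinese("中") -> True, is_chinese("a") -> False,
# is_chinese("\n") -> False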
class SentenceUlit:
    """Split a sentence into runs of consecutive Chinese characters,
    keeping non-Chinese characters in place so corrected runs can be
    merged back afterwards."""

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)  # "" marks a Chinese slot
        self.last_post = False
        self.sentence_batch = []  # runs of consecutive Chinese characters
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        return is_chinese(char)

    def pre_ulit(self):
        for i, d in enumerate(self.sentence):
            if not is_chinese(d):
                # keep non-Chinese characters (digits, ASCII, punctuation) in place
                self.sentence_list[i] = d
                self.last_post = False
            else:
                if not self.last_post:
                    self.sentence_batch.append(d)  # start a new run
                else:
                    self.sentence_batch[-1] += d   # extend the current run
                self.last_post = True

    def inf_ulit(self, sen):
        # write the corrected characters back into the empty Chinese slots
        self.inf_sentence_batch_str = "".join(sen)
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)
        for i, d in enumerate(self.sentence_list):
            if d == "":
                self.sentence_list[i] = self.inf_sentence_batch_srt_list.pop(0)
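

# Usage sketch (illustrative, not part of the original file):
#   s = SentenceUlit("abc你好123世界")
#   s.sentence_batch  -> ["你好", "世界"]
#   s.sentence_list   -> ["a", "b", "c", "", "", "1", "2", "3", "", ""]
#   s.inf_ulit(["你好", "世界"])  # writes corrected runs back into the slots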
class log:
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        fmt = '%Y/%m/%d-%H:%M:%S'
        fmt_day = '%Y-%m-%d'
        now = time.localtime(int(time.time()))
        dt = time.strftime(fmt, now)
        dt_log_file = time.strftime(fmt_day, now)
        log_file = 'log_file/access-%s.log' % dt_log_file
        os.makedirs('log_file', exist_ok=True)
        # 'a' creates the file if it does not exist, so a single append-mode
        # branch suffices
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
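

# e.g. (sketch): log().log("handled /predict") appends a timestamped line to
# log_file/access-YYYY-MM-DD.log, creating the directory and file if needed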
def get_errors(corrected_text, origin_text):
    """Align the corrected text with the original and collect
    (wrong_char, right_char, start, end) substitution details."""
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        # characters the tokenizer maps to [UNK]; the original list was garbled
        # in the source, so this set is assumed from the pycorrector
        # macbert4csc demo
        if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
            # add unk word
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        if ori_char != corrected_text[i]:
            if ori_char.lower() == corrected_text[i]:
                # pass english upper char
                corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
                continue
            sub_details.append((ori_char, corrected_text[i], i, i + 1))
    sub_details = sorted(sub_details, key=operator.itemgetter(2))
    return corrected_text, sub_details
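

# Example (illustrative values, not from the original file):
#   get_errors("我是中国人", "我是中囯人")
#   -> ("我是中国人", [("囯", "国", 3, 4)])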
def main(texts):
    with torch.no_grad():
        outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
    result = []
    for ids, text in zip(outputs.logits, texts):
        # greedy decode, dropping special tokens and the spaces BERT decoding inserts
        _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '')
        corrected_text = _text[:len(text)]
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text,
                       "new": corrected_text,
                       "re_pos": details})
    return result
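

# Example of the returned shape (a sketch; actual corrections depend on the model):
#   main(["今天新情很好"])
#   -> [{"old": "今天新情很好", "new": "今天心情很好", "re_pos": [("新", "心", 2, 3)]}]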
@app.route("/predict", methods=["POST"])
def handle_query():
    logging.info("request from %s", request.remote_addr)
    texts = request.json["texts"]
    return_list = main(texts)
    return_text = {"result": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result
if __name__ == "__main__":
    # logging was already configured at import time above
    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)
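
# Example client call (a sketch; assumes the service is reachable on localhost):
#   import requests
#   resp = requests.post("http://127.0.0.1:16000/predict",
#                        json={"texts": ["今天新情很好"]})
#   print(resp.json()["result"])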