纠错任务
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
5.5 KiB

"""Flask service exposing Chinese spelling correction backed by macbert4csc."""
import json
import logging  # duplicate `import logging` (formerly mid-file) merged here
import operator
import os
import re
import time
import unicodedata
import uuid
from threading import Thread

import torch
from flask import Flask, jsonify, request
from transformers import BertTokenizerFast, BertForMaskedLM

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# File logging in append mode ('a'); 'w' would truncate the log on every start.
logging.basicConfig(
    level=logging.DEBUG,
    filename='rewrite.log',
    filemode='a',
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

db_key_query = 'query'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # allow non-ASCII (Chinese) text in JSON responses

# NOTE(review): the entries of `fuhao_end_sentence` are empty strings in this
# copy — they look mojibake-stripped (likely Chinese end-of-sentence
# punctuation originally). Preserved byte-for-byte; confirm against upstream.
pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["", "", "", "", ""]

# Load tokenizer and masked-LM correction model once at import time.
tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)
def is_chinese(char):
    """Return True if *char* is a CJK character.

    CJK ideographs have Unicode names like 'CJK UNIFIED IDEOGRAPH-4E2D',
    so a substring test on the character name suffices.

    Fix: ``unicodedata.name(char)`` raises ValueError for characters without
    a name (e.g. '\\n' and other control characters), which crashed the
    original on such input. The default-argument form treats unnamed
    characters as non-Chinese instead.
    """
    return 'CJK' in unicodedata.name(char, "")
class SentenceUlit:
    """Splits a sentence into runs of consecutive Chinese characters.

    Non-Chinese characters are copied into per-position output slots
    immediately; Chinese runs are collected into ``sentence_batch`` for
    batch correction and merged back afterwards via :meth:`inf_ulit`.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        # One output slot per input char; "" marks a Chinese char awaiting
        # its corrected replacement.
        self.sentence_list = [""] * len(sentence)
        self.last_post = False  # was the previous char Chinese?
        self.sentence_batch = []  # maximal runs of consecutive Chinese chars
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        """True if *char* is a CJK character; unnamed chars (e.g. control
        characters) count as non-Chinese instead of raising ValueError."""
        return 'CJK' in unicodedata.name(char, "")

    def pre_ulit(self):
        """Partition the sentence: copy non-Chinese chars into their slots,
        group consecutive Chinese chars into ``sentence_batch`` runs.

        Fix: the original called the module-level ``is_chinese`` instead of
        the class's own identical method; unified on ``self.is_chinese``.
        """
        for i, ch in enumerate(self.sentence):
            if self.is_chinese(ch):
                if self.last_post:
                    # Extend the current run.
                    self.sentence_batch[-1] += ch
                else:
                    # Start a new run.
                    self.sentence_batch.append(ch)
                self.last_post = True
            else:
                self.sentence_list[i] = ch
                self.last_post = False

    def inf_ulit(self, sen):
        """Fill corrected Chinese characters back into the empty slots.

        *sen* is an iterable of corrected run strings; their concatenation
        must contain at least as many characters as there are empty slots.
        """
        for part in sen:
            self.inf_sentence_batch_str += part
        self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str)
        for i, ch in enumerate(self.sentence_list):
            if ch == "":
                self.sentence_list[i] = self.inf_sentence_batch_srt_list.pop(0)
class log:
    """Minimal file logger: one timestamped line per call, one file per day
    under ``log_file/`` (directory must already exist)."""

    def __init__(self):
        pass

    def log(*args, **kwargs):
        # NOTE(review): no `self` parameter — when called on an instance, the
        # instance itself is passed as the first *args value and printed.
        # Kept as-is for call-site compatibility; confirm intended usage.
        now = time.localtime(int(time.time()))
        dt = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        dt_log_file = time.strftime('%Y-%m-%d', now)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        # Fix: the original branched on os.path.exists to pick 'w' vs 'a+',
        # but mode 'a' both creates a missing file and appends to an existing
        # one, so a single open covers both cases (and avoids the
        # exists-check race).
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
def get_errors(corrected_text, origin_text):
    """Align the model's corrected text with the original and collect edits.

    The decoded model output can drop characters the tokenizer maps to
    [UNK] (spaces, newlines, some symbols); those are re-inserted from the
    original so both strings stay position-aligned. Case changes on English
    letters are undone rather than reported.

    Args:
        corrected_text: model output, truncated to the input length.
        origin_text: the original input string.
    Returns:
        (aligned_corrected_text, details) where details is a list of
        (original_char, corrected_char, start, end) tuples sorted by start.
    """
    # NOTE(review): several entries below are empty strings — they look
    # mojibake-stripped in this copy; preserved byte-for-byte.
    unk_chars = [' ', '', '', '', '', '', '\n', '', '', '']
    details = []
    for idx, src_char in enumerate(origin_text):
        if src_char in unk_chars:
            # Re-insert a char the tokenizer dropped, realigning positions.
            corrected_text = corrected_text[:idx] + src_char + corrected_text[idx:]
            continue
        if idx >= len(corrected_text):
            # Corrected text is shorter; nothing to compare at this position.
            continue
        new_char = corrected_text[idx]
        if src_char == new_char:
            continue
        if src_char.lower() == new_char:
            # Restore original casing for English letters; not a correction.
            corrected_text = corrected_text[:idx] + src_char + corrected_text[idx + 1:]
            continue
        details.append((src_char, new_char, idx, idx + 1))
    details.sort(key=operator.itemgetter(2))
    return corrected_text, details
def main(texts):
    """Run masked-LM spelling correction over a batch of strings.

    Args:
        texts: list of input strings to correct.
    Returns:
        list of dicts, one per input: {"old": original, "new": corrected,
        "re_pos": [(orig_char, new_char, start, end), ...]}.
    """
    with torch.no_grad():
        encoded = tokenizer(texts, padding=True, return_tensors='pt').to(device)
        outputs = model(**encoded)
    result = []
    for logits, text in zip(outputs.logits, texts):
        # Greedy decode; drop special tokens and the spaces BERT decoding adds.
        decoded = tokenizer.decode(
            torch.argmax(logits, dim=-1), skip_special_tokens=True
        ).replace(' ', '')
        # Truncate to the input length: padding can make the decode longer.
        corrected_text, details = get_errors(decoded[:len(text)], text)
        # Fix: replaced the original's stray debug print() calls (which
        # dumped the full logits tensor per request) with lazy debug logging.
        logging.debug("corrected %r -> %r", text, corrected_text)
        result.append({"old": text, "new": corrected_text, "re_pos": details})
    return result
@app.route("/predict", methods=["POST"])
def handle_query():
print(request.remote_addr)
texts = request.json["texts"]
return_list = main(texts)
return_text = {"resilt": return_list, "probabilities": None, "status_code": 200}
return jsonify(return_text) # 返回结果
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别
filename='rewrite.log',
filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志
# a是追加模式,默认如果不写的话,就是追加模式
format=
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
# 日志格式
)
app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)