"""Flask service exposing a MacBERT-based Chinese spelling corrector.

POST ``/predict`` with a JSON body ``{"texts": [...]}``; the response carries,
for every input text, the corrected text and the list of character
substitutions made by the model.
"""
import json
import logging
import operator
import os
import re
import time
import unicodedata
import uuid
from threading import Thread

import torch
from flask import Flask, jsonify
from flask import request
from transformers import BertTokenizerFast, BertForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(
    level=logging.DEBUG,  # minimum level that gets recorded
    filename='rewrite.log',
    # 'w' rewrites the log on every start, 'a' appends (the default).
    filemode='a',
    # log record layout
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

db_key_query = 'query'
batch_size = 32

app = Flask(__name__)
# Emit non-ASCII (Chinese) characters verbatim in JSON responses.
app.config["JSON_AS_ASCII"] = False
# Newer Flask releases ignore JSON_AS_ASCII and read the setting from the JSON
# provider instead; set both so the behaviour is the same on either version.
_json_provider = getattr(app, "json", None)
if _json_provider is not None:
    _json_provider.ensure_ascii = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese")
model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese")
model.to(device)

# Characters that get_errors() copies from the original text into the
# corrected text instead of comparing (marked "add unk word" in the original).
_UNK_CHARS = frozenset([' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤'])


def is_chinese(char):
    """Return True when the Unicode name of *char* contains ``'CJK'``.

    Characters without a Unicode name (e.g. control characters such as
    ``'\\n'``) are reported as not Chinese instead of raising ``ValueError``.
    """
    return 'CJK' in unicodedata.name(char, '')


class SentenceUlit:
    """Split a sentence into CJK runs and put processed runs back in place.

    Attributes:
        sentence: the original sentence.
        sentence_list: one slot per character of ``sentence``; non-CJK
            characters are stored as-is, CJK positions hold ``""`` until
            ``inf_ulit`` fills them.
        sentence_batch: the consecutive runs of CJK characters, in order.
        last_post: whether the previously scanned character was CJK.
        inf_sentence_batch_str: concatenation of everything given to
            ``inf_ulit`` so far.
    """

    def __init__(self, sentence):
        self.sentence = sentence
        self.sentence_list = [""] * len(sentence)
        self.last_post = False
        self.sentence_batch = []
        self.pre_ulit()
        self.inf_sentence_batch_str = ""

    def is_chinese(self, char):
        """Return True when *char* is a CJK character (see module-level ``is_chinese``)."""
        return is_chinese(char)

    def pre_ulit(self):
        """Scan ``self.sentence`` and populate ``sentence_list`` / ``sentence_batch``."""
        for index, char in enumerate(self.sentence):
            if not is_chinese(char):
                # Non-CJK characters keep their position and end the current run.
                self.sentence_list[index] = char
                self.last_post = False
                continue
            if self.last_post:
                # Still inside a run of CJK characters: extend it.
                self.sentence_batch[-1] += char
            else:
                # First CJK character after a non-CJK one: start a new run.
                self.sentence_batch.append(char)
            self.last_post = True

    def inf_ulit(self, sen):
        """Fill the empty slots of ``sentence_list`` with the characters of *sen*.

        Args:
            sen: an iterable of strings (or a single string); its pieces are
                appended to ``inf_sentence_batch_str`` and then consumed one
                character per empty slot, in order.

        Raises:
            IndexError: if there are fewer characters than empty slots.
        """
        self.inf_sentence_batch_str += "".join(sen)
        chars = list(self.inf_sentence_batch_str)
        used = 0
        for index, slot in enumerate(self.sentence_list):
            if slot == "":
                self.sentence_list[index] = chars[used]
                used += 1
        # Keep whatever was not consumed, as the original pop(0) loop did.
        self.inf_sentence_batch_srt_list = chars[used:]


class log:
    """Minimal helper that appends timestamped lines to a per-day access log."""

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append ``args`` (print-style) to ``log_file/access-<YYYY-MM-DD>.log``.

        Declared static so that both ``log.log(...)`` and ``log().log(...)``
        write only the given arguments (previously an instance call also
        printed the instance itself as the first argument).
        ``kwargs`` are forwarded to ``print``.
        """
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        day = time.strftime('%Y-%m-%d', now)
        log_file = 'log_file/access-%s' % day + ".log"
        # The directory is not created anywhere else; make sure it exists.
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        # Append mode creates the file when it does not exist yet.
        with open(log_file, 'a', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)


def get_errors(corrected_text, origin_text):
    """Align the model output with the original text and list substitutions.

    Args:
        corrected_text: text decoded from the model.
        origin_text: the text that was sent to the model.

    Returns:
        A tuple ``(corrected_text, sub_details)`` where ``corrected_text`` has
        the characters in ``_UNK_CHARS`` and original upper-case letters
        restored, and ``sub_details`` is a list of
        ``(original_char, corrected_char, start, end)`` tuples sorted by start.
    """
    sub_details = []
    for i, ori_char in enumerate(origin_text):
        if ori_char in _UNK_CHARS:
            # add unk word: re-insert the original character at this position
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
            continue
        if i >= len(corrected_text):
            continue
        new_char = corrected_text[i]
        if ori_char == new_char:
            continue
        if ori_char.lower() == new_char:
            # pass english upper char: keep the original casing
            corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
            continue
        sub_details.append((ori_char, new_char, i, i + 1))
    sub_details.sort(key=operator.itemgetter(2))
    return corrected_text, sub_details


def main(texts):
    """Run the correction model over *texts*.

    Args:
        texts: a list of strings; a single string is treated as a one-element
            list.

    Returns:
        A list with one dict per text: ``{"old": text, "new": corrected,
        "re_pos": [(orig_char, new_char, start, end), ...]}``. An empty input
        yields an empty list without calling the model.
    """
    if isinstance(texts, str):
        # Zipping logits against a bare string would pair rows with characters.
        texts = [texts]
    if not texts:
        return []

    with torch.no_grad():
        encoded = tokenizer(texts, padding=True, return_tensors='pt').to(device)
        outputs = model(**encoded)

    result = []
    for ids, text in zip(outputs.logits, texts):
        decoded = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True)
        # Drop the spaces in the decoded string and cut it to the input length.
        corrected_text = decoded.replace(' ', '')[:len(text)]
        logging.debug("corrected text: %s", corrected_text)
        corrected_text, details = get_errors(corrected_text, text)
        result.append({"old": text, "new": corrected_text, "re_pos": details})
    return result


@app.route("/predict", methods=["POST"])
def handle_query():
    """Handle ``POST /predict``: correct the texts given in the JSON body.

    Expects ``{"texts": [str, ...]}`` (a single string is accepted too).
    Responds with HTTP 400 when the body is not JSON or ``texts`` is missing
    or of the wrong type.
    """
    logging.info("predict request from %s", request.remote_addr)

    payload = request.get_json(silent=True)
    texts = payload.get("texts") if isinstance(payload, dict) else None
    is_text_list = isinstance(texts, list) and all(isinstance(t, str) for t in texts)
    if not (isinstance(texts, str) or is_text_list):
        error_body = {
            "resilt": [],
            "probabilities": None,
            "status_code": 400,
            "error": "JSON body must contain 'texts' as a list of strings",
        }
        return jsonify(error_body), 400

    return_list = main(texts)
    # NOTE: "resilt" is misspelled but is part of the response contract;
    # renaming it would break existing clients.
    return_text = {"resilt": return_list, "probabilities": None, "status_code": 200}
    return jsonify(return_text)  # return the result


if __name__ == "__main__":
    # Logging is already configured at import time; a second basicConfig()
    # call would have no effect, so it is not repeated here.
    app.run(host="0.0.0.0", port=16000, threaded=True, debug=False)