import os from flask import Flask, jsonify from flask import request import operator import torch from transformers import BertTokenizerFast, BertForMaskedLM # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = "cpu" import uuid import json from threading import Thread import time import re import logging import unicodedata logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 filename='rewrite.log', filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 # a是追加模式,默认如果不写的话,就是追加模式 format= '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' # 日志格式 ) db_key_query = 'query' batch_size = 32 app = Flask(__name__) app.config["JSON_AS_ASCII"] = False import logging pattern = r"[。]" RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") fuhao_end_sentence = ["。", ",", "?", "!", "…"] tokenizer = BertTokenizerFast.from_pretrained("macbert4csc-base-chinese") model = BertForMaskedLM.from_pretrained("macbert4csc-base-chinese") model.to(device) def is_chinese(char): if 'CJK' in unicodedata.name(char): return True else: return False class SentenceUlit: def __init__(self, sentence): self.sentence = sentence self.sentence_list = [""] * len(sentence) self.last_post = False self.sentence_batch = [] self.pre_ulit() self.inf_sentence_batch_str = "" def is_chinese(self, char): if 'CJK' in unicodedata.name(char): return True else: return False def pre_ulit(self): for i, d in enumerate(self.sentence): bool_ = is_chinese(d) if bool_ == False: self.sentence_list[i] = d self.last_post = False else: if self.last_post == False: self.sentence_batch.append(d) else: self.sentence_batch[-1] += d self.last_post = True def inf_ulit(self, sen): for i in sen: self.inf_sentence_batch_str += i self.inf_sentence_batch_srt_list = list(self.inf_sentence_batch_str) for i, d in enumerate(self.sentence_list): if d == "": zi = self.inf_sentence_batch_srt_list.pop(0) self.sentence_list[i] = zi class log: def __init__(self): pass def log(*args, **kwargs): format = '%Y/%m/%d-%H:%M:%S' format_h = '%Y-%m-%d' value = time.localtime(int(time.time())) dt = time.strftime(format, value) dt_log_file = time.strftime(format_h, value) log_file = 'log_file/access-%s' % dt_log_file + ".log" if not os.path.exists(log_file): with open(os.path.join(log_file), 'w', encoding='utf-8') as f: print(dt, *args, file=f, **kwargs) else: with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: print(dt, *args, file=f, **kwargs) def get_errors(corrected_text, origin_text): sub_details = [] for i, ori_char in enumerate(origin_text): if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']: # add unk word corrected_text = corrected_text[:i] + ori_char + corrected_text[i:] continue if i >= len(corrected_text): continue if ori_char != corrected_text[i]: if ori_char.lower() == corrected_text[i]: # pass english upper char corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:] continue sub_details.append((ori_char, corrected_text[i], i, i + 1)) sub_details = sorted(sub_details, key=operator.itemgetter(2)) return corrected_text, sub_details def main(texts): batch_size_sentent_class = [] for text in texts: batch_size_sentent_class.append(SentenceUlit(text)) batch_pre = [] batch_nums = [] for sentent_class in batch_size_sentent_class: sentents = sentent_class.sentence_batch batch_pre.extend(sentents) batch_nums.append(len(sentents)) with torch.no_grad(): # input_pre = tokenizer(batch_pre, padding=True, return_tensors='pt').to(device) # input_ids = input_pre['input_ids'].to(device) # token_type_ids = input_pre["token_type_ids"].to(device) # attention_mask = input_pre['attention_mask'].to(device) # outputs = model(input_ids, token_type_ids, attention_mask) outputs = model(**tokenizer(batch_pre, padding=True, return_tensors='pt').to(device)) batch_res = [] for ids,data_dan in zip(outputs.logits,batch_pre): _text = tokenizer.decode(torch.argmax(ids, dim=-1), skip_special_tokens=True).replace(' ', '') corrected_text = _text[:len(data_dan)] batch_res.append(corrected_text) print(batch_pre) print(batch_res) batch_new = [] index = 0 for i in batch_nums: index_new = index + i batch_new.append(batch_res[index:index_new]) index = index_new batch_pre_data = [] for dan, sentent_class in zip(batch_new, batch_size_sentent_class): sentent_class.inf_ulit(dan) batch_pre_data.append("".join(sentent_class.sentence_list)) result = [] for text, corrected_text in zip(texts,batch_pre_data): corrected_text, details = get_errors(corrected_text, text) result.append({"old": text, "new": corrected_text, "re_pos": details}) return result @app.route("/predict", methods=["POST"]) def handle_query(): print(request.remote_addr) texts = request.json["texts"] return_list = main(texts) return_text = {"resilt": return_list, "probabilities": None, "status_code": 200} return jsonify(return_text) # 返回结果 if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 filename='rewrite.log', filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 # a是追加模式,默认如果不写的话,就是追加模式 format= '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' # 日志格式 ) app.run(host="0.0.0.0", port=16001, threaded=True, debug=False)