# 降重增强版 ("duplicate-reduction, enhanced version"): a Flask + vLLM service that
# paraphrases Chinese text sentence-by-sentence, queued through Redis.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import json
import logging
import re
import time
import uuid
from threading import Thread

import redis
from flask import Flask, jsonify, request
from vllm import LLM, SamplingParams
# Log to rewrite.log. filemode='a' appends across restarts ('w' would truncate
# and overwrite the previous log on every start; 'a' is also the default).
logging.basicConfig(
    level=logging.DEBUG,
    filename='rewrite.log',
    filemode='a',
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)
# Note: decode_responses only takes effect when set on the ConnectionPool; passing
# it to redis.Redis() alongside connection_pool= is ignored, which is why the
# worker below still calls .decode('UTF-8') on the raw bytes from lpop().
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7,
                            password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool)
db_key_query = 'query'        # list of pending request descriptors
db_key_querying = 'querying'  # set of ids currently in flight
db_key_queryset = 'queryset'  # set of all ids ever submitted
batch_size = 32
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
# Sentence-ending punctuation. The original literals were mis-encoded to empty
# strings; the values below are assumed restorations of common Chinese end marks.
fuhao_end_sentence = ["。", "!", "?", ";", "…"]
# Load the model. stop="</s>" ends generation at the EOS token; max_tokens caps
# the length of each rewrite.
sampling_params = SamplingParams(temperature=0.95, top_p=0.7, presence_penalty=1.1,
                                 stop="</s>", max_tokens=4096)
models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1"
llm = LLM(model=models_path, tokenizer_mode="slow")
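
# A minimal sketch (assumed usage; defined for illustration and never called) of
# how the vLLM objects above are driven: generate() takes a list of prompt
# strings and returns one RequestOutput per prompt, whose best candidate is
# .outputs[0].text.
def _demo_generate():
    demo_outputs = llm.generate(["User:你好\nAssistant:"], sampling_params)
    for out in demo_outputs:
        print(out.outputs[0].text)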
class log:
    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        # Append one timestamped line to a per-day access log under log_file/.
        fmt = '%Y/%m/%d-%H:%M:%S'
        fmt_day = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(fmt, value)
        dt_log_file = time.strftime(fmt_day, value)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        os.makedirs('log_file', exist_ok=True)  # open() below fails without the dir
        # 'a+' creates the file if it does not exist, so one branch covers both cases.
        with open(log_file, 'a+', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
def get_dialogs_index(line: str):
    """
    Find quoted dialogue and its position.
    :param line: the text
    :return: dialogs_text: the quoted spans
             dialogs_index: character positions inside quotes
             other_index: character positions outside quotes
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        dialogs_index.extend(range(dialog.start(), dialog.end()))
    other_index = [i for i in range(len(line)) if i not in dialogs_index]
    return dialogs_text, dialogs_index, other_index
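
# Illustrative sketch (assumed example input; defined for illustration and never
# called): the quoted span lands in dialogs_index, the rest in other_index.
def _demo_get_dialogs_index():
    dialogs_text, dialogs_index, other_index = get_dialogs_index('他说"你好"再见')
    print(dialogs_text)   # ['"你好"']
    print(dialogs_index)  # [2, 3, 4, 5]
    print(other_index)    # [0, 1, 6, 7]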
def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
    """Recursively split an over-long sentence into chunks of at most 120
    characters, cutting at the last punctuation mark outside any quotes."""
    # Clause-level split punctuation. The original literals were mis-encoded to
    # empty strings; the values below are assumed restorations.
    fuhao = ["，", "、", "？", "！"]
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        chulipangban_return_list.append([text_1, snetence_id, short_num])
        return chulipangban_return_list
    # Scan backwards for the last punctuation that is not inside a quoted span.
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in dialogs_index:
                continue
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
            if text_2 != "":
                if i + 1 != 120:
                    text_2 = text_1[i + 1:] + text_2
            break
    # No usable split point: keep the raw 120-character chunk as-is.
    if text_1_new == "":
        chulipangban_return_list.append([text_1, snetence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list
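
# Illustrative sketch (assumed example; defined for illustration and never
# called): a 150-character sentence splits into a head cut at the last comma
# within 120 characters plus a tail, tagged short_num 0 and 1 so they can be
# glued back together after rewriting.
def _demo_chulichangju_1():
    long_sentence = "很长的句子，" * 25  # 150 characters
    for text, sid, short_num in chulichangju_1(long_sentence, 0, [], 0):
        print(sid, short_num, len(text))  # (0, 0, 120) then (0, 1, 30)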
def chulipangban_test_1(snetence_id, text):
    """Split one paragraph into [sentence, snetence_id, short_num] triples:
    sentences are cut at 。, and any sentence over 120 characters is broken up
    by chulichangju_1."""
    # Quote handling: split on each quoted span and re-join with the same span
    # (effectively a no-op, kept from the original flow).
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    for dialogs_text_dan in dialogs_text:
        text_dan_list = text.split(dialogs_text_dan)
        text = dialogs_text_dan.join(text_dan_list)
    # The separator literal was mis-encoded to "" (which str.split rejects);
    # restored to 。 to match `pattern` above.
    sentence_list = text.split("。")
    sentence_batch_list = []
    sentence_batch_length = 0
    for sentence in sentence_list[:-1]:
        if len(sentence) < 120:
            sentence_batch_length += len(sentence)
            sentence_batch_list.append([sentence + "。", snetence_id, 0])
        else:
            sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
            for sentence_short in sentence_split_list[:-1]:
                sentence_batch_list.append(sentence_short)
            # Re-attach the 。 to the text field of the final chunk.
            last_short = sentence_split_list[-1]
            last_short[0] += "。"
            sentence_batch_list.append(last_short)
    # Trailing fragment after the final 。, if any.
    if sentence_list[-1] != "":
        if len(sentence_list[-1]) < 120:
            sentence_batch_length += len(sentence_list[-1])
            sentence_batch_list.append([sentence_list[-1], snetence_id, 0])
        else:
            sentence_split_list = chulichangju_1(sentence_list[-1], snetence_id, [], 0)
            for sentence_short in sentence_split_list:
                sentence_batch_list.append(sentence_short)
    return sentence_batch_list
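
# Illustrative sketch (assumed example; defined for illustration and never
# called): two short sentences come back as two triples, each tagged with the
# paragraph id and short_num == 0 (no length split needed).
def _demo_chulipangban_test_1():
    triples = chulipangban_test_1("0", "今天天气很好。我们出去走走。")
    print(triples)  # [['今天天气很好。', '0', 0], ['我们出去走走。', '0', 0]]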
def paragraph_test(texts: dict):
    """Flatten a {paragraph_id: paragraph_text} dict into sentence triples."""
    text_new = []
    for i, text in texts.items():
        text_list = chulipangban_test_1(i, text)
        text_new.extend(text_list)
    return text_new
def batch_data_process(text_list):
    """Group sentence triples into batches of roughly 500 characters; the
    sentence that pushes a batch over 500 starts the next batch."""
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []
    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = [sentence_]
            # Carry the overflow sentence's length into the new batch.
            sentence_batch_length = len(sentence_[0])
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list
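
# Illustrative sketch (assumed example; defined for illustration and never
# called): nine 100-character sentences pack into a batch of five (the first to
# exceed 500 characters) and a batch of four.
def _demo_batch_data_process():
    triples = [["字" * 100, "0", 0] for _ in range(9)]
    print([len(b) for b in batch_data_process(triples)])  # [5, 4]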
def batch_predict(batch_data_list):
    '''
    Predict one batch of data. Currently a pass-through: the model call is
    commented out, so each sentence is returned unchanged with its metadata.
    @param batch_data_list:
    @return:
    '''
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_snetence_id_list = []
    for i in batch_data_list:
        batch_data_text_list.append(i[0])
        batch_data_snetence_id_list.append(i[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list
    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
        batch_data_list_new.append([text] + sentence_id)
    return batch_data_list_new
def predict_data_post_processing(text_list):
    """Reassemble rewritten chunks: a triple with short_num != 0 is glued onto
    the previous chunk, then the chunks are concatenated per paragraph id."""
    text_list_sentence = []
    for i in range(len(text_list)):
        if text_list[i][2] != 0:
            text_list_sentence[-1][0] += text_list[i][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])
    return_list = {}
    sentence_one = []
    sentence_id = text_list_sentence[0][1]
    for i in text_list_sentence:
        if i[1] == sentence_id:
            sentence_one.append(i[0])
        else:
            return_list[sentence_id] = "".join(sentence_one)
            sentence_id = i[1]
            sentence_one = [i[0]]
    if sentence_one:
        return_list[sentence_id] = "".join(sentence_one)
    return return_list
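
# Illustrative sketch (assumed example; defined for illustration and never
# called): chunks 0/1 of paragraph "0" are glued back into one string, while
# paragraph "1" stays separate.
def _demo_predict_data_post_processing():
    chunks = [["前半句，", "0", 0], ["后半句。", "0", 1], ["另一段。", "1", 0]]
    print(predict_data_post_processing(chunks))
    # {'0': '前半句，后半句。', '1': '另一段。'}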
# def main(text:list):
# # text_list = paragraph_test(text)
# # batch_data = batch_data_process(text_list)
# # text_list = []
# # for i in batch_data:
# # text_list.extend(i)
# # return_list = predict_data_post_processing(text_list)
# # return return_list
def pre_sentence_ulit(sentence):
    """Strip the leading '改写后:' ("after rewriting:") marker that the model
    tends to prepend, keeping only the rewritten text."""
    if "改写后:" in sentence:
        sentence_lable_index = sentence.index("改写后:")
        sentence = sentence[sentence_lable_index + 4:]
    return sentence
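
# Illustrative sketch (assumed example; defined for illustration and never
# called): the marker and everything before it are dropped; text without the
# marker passes through unchanged.
def _demo_pre_sentence_ulit():
    print(pre_sentence_ulit("改写后:新的句子。"))  # 新的句子。
    print(pre_sentence_ulit("没有标记的句子。"))   # 没有标记的句子。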
def main(texts: dict):
    text_list = paragraph_test(texts)
    text_info = []
    text_sentence = []
    text_list_new = []
    # Build one prompt per chunk: chunks longer than 7 characters get a rewrite
    # prompt ("paraphrase with large changes, same meaning, at least as many
    # characters"); very short chunks get a "leave unchanged" prompt.
    for i in text_list:
        if len(i[0]) > 7:
            text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(i[0])
        else:
            text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(i[0])
        text_sentence.append(text)
        text_info.append([i[1], i[2]])
    # Batched vLLM inference over all prompts at once.
    outputs = llm.generate(text_sentence, sampling_params)
    generated_text_list = [""] * len(text_sentence)
    # generate() returns one RequestOutput per prompt in prompt order, so index
    # positionally; output.request_id is a counter global to the engine and is
    # only zero-based on the first generate() call of the process.
    for i, output in enumerate(outputs):
        generated_text_list[i] = output.outputs[0].text
    # Post-process: strip the "改写后:" marker from rewrites; restore short
    # chunks verbatim since the model was told not to change them.
    for i in range(len(text_list)):
        if len(text_list[i][0]) > 7:
            generated_text_list[i] = pre_sentence_ulit(generated_text_list[i])
        else:
            generated_text_list[i] = text_list[i][0]
    for i, j in zip(generated_text_list, text_info):
        text_list_new.append([i] + j)
    return_list = predict_data_post_processing(text_list_new)
    return return_list
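
# Illustrative sketch (assumed example; defined for illustration and never
# called, and it requires the model to be loaded): rewrite two paragraphs keyed
# by id and get back a {paragraph_id: rewritten_text} dict.
def _demo_main():
    result = main({"0": "这是第一段需要改写的文字。", "1": "这是第二段。"})
    print(result)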
# @app.route('/droprepeat/', methods=['POST'])
# def sentence():
# print(request.remote_addr)
# texts = request.json["texts"]
# text_type = request.json["text_type"]
# print("原始语句" + str(texts))
# # question = question.strip('。、!??')
#
# if isinstance(texts, dict):
# texts_list = []
# y_pred_label_list = []
# position_list = []
#
# # texts = texts.replace('\'', '\"')
# if texts is None:
# return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
# return jsonify(return_text)
# else:
# assert text_type in ['focus', 'chapter']
# if text_type == 'focus':
# texts_list = main(texts)
# if text_type == 'chapter':
# texts_list = main(texts)
# return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
# else:
# return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False}
# return jsonify(return_text)
def classify():  # background worker: pull requests from Redis and run the model
    while True:
        if redis_.llen(db_key_query) == 0:  # nothing queued: poll again after a pause
            time.sleep(3)
            continue
        query = redis_.lpop(db_key_query).decode('UTF-8')  # raw bytes -> str
        data_dict_path = json.loads(query)
        path = data_dict_path['path']
        with open(path, encoding='utf8') as f1:
            # Load the request payload written by handle_query().
            data_dict = json.load(f1)
        query_id = data_dict['id']
        texts = data_dict["text"]
        text_type = data_dict["text_type"]
        assert text_type in ['focus', 'chapter']
        # Both types currently take the same path; kept separate for future use.
        if text_type == 'focus':
            texts_list = main(texts)
        elif text_type == 'chapter':
            texts_list = main(texts)
        else:
            texts_list = []
        return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
        load_result_path = "./new_data_logs/{}.json".format(query_id)
        print("query_id: ", query_id)
        print("load_result_path: ", load_result_path)
        with open(load_result_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps Chinese readable instead of \u escapes;
            # indent=4 pretty-prints the JSON.
            json.dump(return_text, f2, ensure_ascii=False, indent=4)
        debug_id_1 = 1
        redis_.set(query_id, load_result_path, 86400)  # result path, 24h TTL
        debug_id_2 = 2
        redis_.srem(db_key_querying, query_id)  # no longer in flight
        debug_id_3 = 3
        log.log('start at',
                'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format(
                    query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3))
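
# Illustrative sketch (assumed example; defined for illustration and never
# called): enqueue one request by hand, the same way handle_query() does, so a
# single classify() pass can be exercised without going through HTTP.
def _demo_enqueue():
    demo_id = str(uuid.uuid1())
    payload = {"id": demo_id, "text": {"0": "这是一段测试文字。"}, "text_type": "chapter"}
    demo_path = './request_data_logs/{}.json'.format(demo_id)
    with open(demo_path, 'w', encoding='utf8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
    redis_.rpush(db_key_query, json.dumps({"id": demo_id, "path": demo_path}))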
@app.route("/predict", methods=["POST"])
def handle_query():
print(request.remote_addr)
texts = request.json["texts"]
text_type = request.json["text_type"]
if texts is None:
return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402}
return jsonify(return_text)
if isinstance(texts, dict):
id_ = str(uuid.uuid1()) # 为query生成唯一标识
print("uuid: ", uuid)
d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id
load_request_path = './request_data_logs/{}.json'.format(id_)
with open(load_request_path, 'w', encoding='utf8') as f2:
# ensure_ascii=False才能输入中文,否则是Unicode字符
# indent=2 JSON数据的缩进,美观
json.dump(d, f2, ensure_ascii=False, indent=4)
redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path})) # 加入redis
redis_.sadd(db_key_querying, id_)
redis_.sadd(db_key_queryset, id_)
return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
print("ok")
else:
return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
return jsonify(return_text) # 返回结果
t = Thread(target=classify)
t.start()
if __name__ == "__main__":
    # Logging is already configured at import time above, so it is not repeated here.
    app.run(host="0.0.0.0", port=14002, threaded=True, debug=False)
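
# Illustrative client sketch (assumed example, shown as comments so it never runs
# inside the server): /predict only returns the query id; the finished text is
# written to ./new_data_logs/<id>.json and the path is stored in Redis under that
# id for 24 hours.
#
#   import requests  # client-side dependency
#   resp = requests.post("http://127.0.0.1:14002/predict",
#                        json={"texts": {"0": "这是一段需要改写的文字。"},
#                              "text_type": "chapter"})
#   print(resp.json())  # {"texts": {"id": "..."}, "probabilities": None, "status_code": 200}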