You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
821 lines
30 KiB
821 lines
30 KiB
import os
|
|
from flask import Flask, jsonify
|
|
from flask import request
|
|
import requests
|
|
import redis
|
|
import uuid
|
|
import json
|
|
from threading import Thread
|
|
import time
|
|
import re
|
|
import logging
|
|
import concurrent.futures
|
|
import socket
|
|
from sentence_spliter.logic_graph_en import long_cuter_en
|
|
from sentence_spliter.automata.state_machine import StateMachine
|
|
from sentence_spliter.automata.sequence import EnSequence # 调取英文 Sequence
|
|
from openai import OpenAI
|
|
import os
|
|
import concurrent.futures
|
|
import traceback
|
|
|
|
# Initialise the OpenAI-compatible client for Alibaba Cloud DashScope (Bailian).
# The API key is taken from the DASHSCOPE_API_KEY environment variable; the
# literal is only a fallback so existing deployments keep working.
# NOTE(review): a live key committed to source should be rotated and removed.
client = OpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY") or "sk-b05ca5cefaf348c6a78443954472c1e4",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
|
|
|
|
# Root logger configuration: DEBUG and above, appended to rewrite.log.
# filemode 'a' keeps history across restarts ('w' would truncate the file on
# every start).
logging.basicConfig(
    filename='rewrite.log',
    filemode='a',
    level=logging.DEBUG,
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)
|
|
|
|
# Redis connection shared by the HTTP handler (producer) and the worker
# thread (consumer).
# NOTE(review): decode_responses=True passed to Redis() appears to be ignored
# when an explicit connection_pool is supplied, so replies arrive as bytes
# (classify() decodes them) - confirm before relying on either behaviour.
pool = redis.ConnectionPool(
    host='localhost',
    port=63179,
    db=18,
    password="zhicheng123*",
    max_connections=100,
)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

# Redis list holding pending jobs (JSON {"id": ..., "path": ...}).
db_key_query = 'query'
# Redis set of job ids that have been queued but not yet finished.
db_key_querying = 'querying'
# Number of concurrent model requests per batch in get_multiple_urls().
batch_size = 20

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
|
|
|
|
import logging
|
|
|
|
# Sentence delimiter (Chinese full stop).
pattern = r"[。]"
# Quoted spans ("...", '...', “...”); text inside them is never split.
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
# Sentence-ending punctuation marks.
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

# Heading / section patterns used by chinese_ulit() to leave titles unchanged.
# Raw strings: in a normal literal "\s" is an invalid escape sequence
# (SyntaxWarning on Python 3.12+). The regex engine resolves "\uXXXX" itself,
# so matching behaviour is identical to the previous non-raw strings.
pantten_biaoti_0 = r'^[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_1 = r'^第[一二三四五六七八九]章\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_2 = r'^[0-9.]+\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_3 = r'^[((][1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][)_)][、.]{0,}?\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
# Abstract / acknowledgements markers.
pantten_biaoti_4 = '(摘要)'
pantten_biaoti_5 = '(致谢)'
|
|
|
|
|
|
def get_host_ip():
    """Return the IPv4 address of the interface used for outbound traffic.

    A UDP socket is "connected" to 8.8.8.8 only so the OS selects a route and
    local address; that address is then read back with ``getsockname``.

    :return: local IP address as a string
    :raises OSError: if the socket cannot be created or no route exists
    """
    # The context manager closes the socket on every path. The previous
    # try/finally referenced ``s`` even when socket() itself had failed,
    # replacing the real OSError with an UnboundLocalError.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]
|
|
|
|
|
|
# Previous endpoints on port 12001, kept for reference:
# chatgpt_url_predict = "http://{}:12001/predict".format(str(get_host_ip()))
# chatgpt_url_search = "http://{}:12001/search".format(str(get_host_ip()))

# Resolve the local IP once instead of opening a socket per URL.
_local_host_ip = str(get_host_ip())
chatgpt_url_predict = "http://{}:26000/predict".format(_local_host_ip)
chatgpt_url_search = "http://{}:26000/search".format(_local_host_ip)
|
|
|
|
|
|
def smtp_f(name):
    """Send a best-effort alert e-mail whose subject is *name*.

    Failures are reported on stdout and never propagated: this is called from
    the background worker, and an unreachable mail server (DNS failure,
    refused connection, timeout - all ``OSError``) must not kill that thread.

    :param name: subject line of the alert
    """
    import smtplib
    from email.mime.text import MIMEText
    from email.header import Header

    sender = '838878981@qq.com'  # sending mailbox
    receivers = ['838878981@qq.com']  # receiving mailboxes
    # NOTE(review): SMTP authorisation code is hard-coded in source.
    auth_code = "jfqtutaiwrtdbcge"

    message = MIMEText('降重项目出错,紧急', 'plain', 'utf-8')
    message['From'] = Header("Sender<%s>" % sender)
    message['To'] = Header("Receiver<%s>" % receivers[0])
    message['Subject'] = Header(name, 'utf-8')

    server = None
    try:
        # Timeout so a hung connect cannot block the worker forever.
        server = smtplib.SMTP_SSL('smtp.qq.com', 465, timeout=30)
        server.login(sender, auth_code)
        server.sendmail(sender, receivers, message.as_string())
        print("邮件发送成功")
    except (smtplib.SMTPException, OSError):
        print("Error: 无法发送邮件")
    finally:
        # close() (not quit()/with) as before: the socket is always released
        # without sending QUIT, so a QUIT error cannot override a successful send.
        if server is not None:
            server.close()
|
|
|
|
|
|
class log:
    """Minimal append-only access logger (``log_file/access-YYYY-MM-DD.log``)."""

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append ``print(timestamp, *args, **kwargs)`` to today's log file.

        The timestamp is formatted ``%Y/%m/%d-%H:%M:%S``; extra keyword
        arguments (e.g. ``sep``) are forwarded to :func:`print`. The
        ``log_file`` directory is created when missing instead of letting
        ``open`` raise ``FileNotFoundError`` in the caller.
        """
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        log_dir = 'log_file'
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, 'access-%s.log' % time.strftime('%Y-%m-%d', now))
        # Append mode creates the file if needed, so no separate 'w' branch.
        with open(log_path, 'a', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)
|
|
|
|
|
|
def dialog_line_parse(url, text):
    """POST *text* as JSON to the model endpoint *url*.

    :param url: model endpoint URL
    :param text: JSON-serialisable request payload
    :return: decoded JSON body on HTTP 200, otherwise ``{}`` (after printing
        the status code and response text)
    """
    response = requests.post(url, json=text, timeout=100000)
    if response.status_code != 200:
        print("【{}】 Failed to get a proper response from remote server. Status Code: {}. Response: {}".format(
            url, response.status_code, response.text))
        return {}
    return response.json()
|
|
|
|
|
|
def get_dialogs_index(line: str):
    """Locate quoted spans (matches of ``RE_DIALOG``) in *line*.

    :param line: text to scan
    :return: ``(dialogs_text, dialogs_index, other_index)``:
        *dialogs_text* - matched quoted strings, in order;
        *dialogs_index* - every character position covered by a match;
        *other_index* - all remaining positions, ascending.
    """
    dialogs_text = []
    dialogs_index = []
    # One regex pass yields both the matched text and its positions.
    for match in RE_DIALOG.finditer(line):
        dialogs_text.append(match.group())
        dialogs_index.extend(range(match.start(), match.end()))
    # A set keeps this linear; testing membership against the list was O(n^2).
    covered = set(dialogs_index)
    other_index = [i for i in range(len(line)) if i not in covered]
    return dialogs_text, dialogs_index, other_index
|
|
|
|
|
|
def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num, sen_len):
    """Cut an over-long sentence into chunks of at most *sen_len* characters.

    Each chunk ends at the last ``,``/``?``/``!``/``…`` within the first
    *sen_len* characters that is not inside a quoted span (as reported by
    ``get_dialogs_index``). Without such a mark the text is hard-cut at
    *sen_len*. Characters after the cut are carried into the next chunk.

    :param text: sentence to split
    :param snetence_id: id copied into every chunk
    :param chulipangban_return_list: list receiving ``[chunk, snetence_id,
        short_num]`` entries (mutated in place)
    :param short_num: number given to the first chunk; incremented per chunk
    :param sen_len: maximum chunk length, must be positive
    :return: *chulipangban_return_list*
    :raises ValueError: if *sen_len* is not positive (splitting would never
        make progress)
    """
    if sen_len <= 0:
        raise ValueError("sen_len must be positive, got {!r}".format(sen_len))
    fuhao = [",", "?", "!", "…"]
    # A loop rather than recursion, so very long inputs cannot exceed
    # Python's recursion limit.
    while True:
        in_dialog = set(get_dialogs_index(text)[1])
        text_1 = text[:sen_len]
        text_2 = text[sen_len:]
        if text_2 == "":
            chulipangban_return_list.append([text_1, snetence_id, short_num])
            return chulipangban_return_list
        cut_found = False
        for i in range(len(text_1) - 1, -1, -1):
            if text_1[i] in fuhao and i not in in_dialog:
                chulipangban_return_list.append([text_1[:i + 1], snetence_id, short_num])
                # Whatever follows the punctuation goes back in front of the rest.
                text_2 = text_1[i + 1:] + text_2
                cut_found = True
                break
        if not cut_found:
            chulipangban_return_list.append([text_1, snetence_id, short_num])
        short_num += 1
        text = text_2
|
|
|
|
|
|
# def get_multiple_urls(urls):
|
|
# with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
# future_to_url = {executor.submit(dialog_line_parse, url[1], url[2]): url for url in urls}
|
|
#
|
|
#
|
|
# results = []
|
|
# for future in concurrent.futures.as_completed(future_to_url):
|
|
# url = future_to_url[future]
|
|
# try:
|
|
# data = future.result()
|
|
# results.append((url, data))
|
|
# except Exception as e:
|
|
# results.append((url, f"Error: {str(e)}"))
|
|
# return results
|
|
|
|
# def request_api_chatgpt(prompt):
|
|
# data = {
|
|
# "content": prompt,
|
|
# "model": "gpt-4-turbo",
|
|
# "top_p": 0.7,
|
|
# "temperature": 0.6
|
|
# }
|
|
# response = requests.post(
|
|
# chatgpt_url_predict,
|
|
# json=data,
|
|
# timeout=100000
|
|
# )
|
|
# if response.status_code == 200:
|
|
# return response.json()
|
|
# else:
|
|
# # logger.error(
|
|
# # "【{}】 Failed to get a proper response from remote "
|
|
# # "server. Status Code: {}. Response: {}"
|
|
# # "".format(url, response.status_code, response.text)
|
|
# # )
|
|
# print("Failed to get a proper response from remote "
|
|
# "server. Status Code: {}. Response: {}"
|
|
# "".format(response.status_code, response.text))
|
|
# return {}
|
|
|
|
|
|
def request_api_chatgpt(prompt, prompt2):
    """Send *prompt* to the ``qwq-plus`` model and return the streamed answer.

    Reasoning ("thinking") deltas are skipped; only answer content is kept.

    :param prompt: instruction text sent as the single user message
    :param prompt2: value returned instead if anything fails (callers pass
        the original sentence)
    :return: concatenated answer text, or *prompt2* on any exception
    """
    answer_content = ""
    try:
        completion = client.chat.completions.create(
            model="qwq-plus",
            messages=[{"role": "user", "content": prompt}],
            # QwQ models only support streaming calls.
            stream=True,
            # Uncommenting this makes the last chunk carry token usage:
            # stream_options={"include_usage": True}
        )
        for chunk in completion:
            # Usage-only chunks have an empty ``choices`` list.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if getattr(delta, 'reasoning_content', None) is not None:
                continue
            # Role-only or final chunks can carry content=None; ``+= None``
            # raised TypeError and threw the whole answer away via the fallback.
            if delta.content:
                answer_content += delta.content
    except Exception:
        # Best-effort: fall back to the original text so the job still completes.
        answer_content = prompt2
        print("prompt2", prompt2)
        traceback.print_exc()
    return answer_content
|
|
|
|
|
|
def uuid_search(uuid):
    """Look up the result for job *uuid* at ``chatgpt_url_search``.

    :param uuid: job id sent as ``{"id": uuid}``
    :return: decoded JSON body on HTTP 200, otherwise ``{}`` (after printing
        the status code and response text)
    """
    payload = {"id": uuid}
    response = requests.post(chatgpt_url_search, json=payload, timeout=100000)
    if response.status_code == 200:
        return response.json()
    print("Failed to get a proper response from remote server. Status Code: {}. Response: {}".format(
        response.status_code, response.text))
    return {}
|
|
|
|
|
|
def uuid_search_mp(results):
    """Poll ``uuid_search`` until every submitted job has produced text.

    :param results: submission responses; each job id is read from
        ``item["texts"]["id"]``
    :return: result texts aligned with *results*; an empty result text is
        recorded as ``"Empty character"``

    Polling runs in rounds separated by 3 s. Jobs that already have an answer
    are not queried again. A failed lookup (``uuid_search`` returns ``{}`` on
    HTTP errors) counts as "not ready" instead of raising ``KeyError``.
    NOTE: there is no overall timeout; a job that never finishes keeps this
    loop running.
    """
    results_list = [""] * len(results)
    while any(answer == "" for answer in results_list):
        for i, item in enumerate(results):
            if results_list[i] != "":
                continue
            result = uuid_search(item["texts"]["id"])
            if result.get("code") == 200:
                text = result["text"]
                results_list[i] = text if text != "" else "Empty character"
        time.sleep(3)
    return results_list
|
|
|
|
|
|
def get_multiple_urls(text_info):
    """Query the model for every row flagged for rewriting.

    :param text_info: rows ``[original, sentence_id, short_num, if_change,
        prompt]``; rows whose ``if_change == True`` are sent to
        ``request_api_chatgpt`` with the original text as the fallback
    :return: the same list, each row extended by one element: the model
        answer for flagged rows, the original text for the others

    Requests run ``batch_size`` at a time in a thread pool, with a 2 s pause
    after each batch.
    """
    flagged = [(idx, row[4], row[0]) for idx, row in enumerate(text_info) if row[3] == True]
    prompts = [prompt for _, prompt, _ in flagged]
    originals = [original for _, _, original in flagged]
    total = len(prompts)

    answers = []
    for start in range(0, total, batch_size):
        print("第{}条,总共{}条".format(str(start), str(total)))
        stop = start + batch_size
        with concurrent.futures.ThreadPoolExecutor(batch_size) as workers:
            # map() keeps answers in submission order.
            answers.extend(workers.map(request_api_chatgpt, prompts[start:stop], originals[start:stop]))
        time.sleep(2)

    answer_by_row = dict(zip((idx for idx, _, _ in flagged), answers))
    for idx, row in enumerate(text_info):
        row.append(answer_by_row.get(idx, row[0]))
    return text_info
|
|
|
|
|
|
def chulipangban_test_1(snetence_id, text, sen_len):
    """Split one paragraph into sentence entries ``[sentence, snetence_id, short_num]``.

    Text without Chinese characters is segmented by the English sentence
    splitter (``long_cuter_en``, max_len=25, min_len=3). Chinese text is split
    on "。" (the mark is kept on each sentence). Sentences of *sen_len* or
    more characters are cut further by ``chulichangju_1``, whose entries carry
    increasing ``short_num`` values.

    The former "quote handling" loop (``text = q.join(text.split(q))`` for
    each quoted span) returned the text unchanged. It is removed together
    with its O(n^2) ``get_dialogs_index`` call and unused locals.

    :param snetence_id: paragraph id copied into every entry
    :param text: paragraph text
    :param sen_len: length threshold for Chinese sentences
    :return: list of entries in paragraph order
    """
    sentence_batch_list = []

    if not has_chinese(text):
        machine = StateMachine(long_cuter_en(max_len=25, min_len=3))
        m_input = EnSequence(text)
        machine.run(m_input)
        for v in m_input.sentence_list():
            sentence_batch_list.append([v, snetence_id, 0])
        return sentence_batch_list

    spilt_word = "。"
    # split() always yields at least one element; the last piece has no "。".
    *complete, tail = text.split(spilt_word)
    for sentence in complete:
        if len(sentence) < sen_len:
            sentence_batch_list.append([sentence + spilt_word, snetence_id, 0])
        else:
            pieces = chulichangju_1(sentence, snetence_id, [], 0, sen_len)
            # Only the final chunk of the sentence gets the full stop back.
            pieces[-1][0] = pieces[-1][0] + spilt_word
            sentence_batch_list.extend(pieces)

    if tail != "":
        if len(tail) < sen_len:
            sentence_batch_list.append([tail, snetence_id, 0])
        else:
            sentence_batch_list.extend(chulichangju_1(tail, snetence_id, [], 0, sen_len))

    return sentence_batch_list
|
|
|
|
|
|
def paragraph_test(texts: dict):
    """Split every paragraph of *texts* into sentence entries.

    :param texts: mapping of paragraph id -> paragraph text
    :return: flat list of ``[sentence, paragraph_id, short_num]`` entries
        produced by ``chulipangban_test_1``, paragraphs in dict order
    """
    text_new = []
    for paragraph_id, paragraph in texts.items():
        # Paragraphs containing Chinese use sen_len=120, all others 500.
        limit = 120 if has_chinese(paragraph) else 500
        text_new.extend(chulipangban_test_1(paragraph_id, paragraph, sen_len=limit))
    return text_new
|
|
|
|
|
|
def batch_data_process(text_list, max_batch_chars=500):
    """Group sentence entries into batches of bounded total text length.

    A new batch starts whenever adding the next entry would push the current
    batch above *max_batch_chars* characters. An entry longer than the limit
    on its own still forms a single-entry batch.

    Fixes two bugs in the original:
      * after moving an entry into a new batch, the running length was reset
        to 0 instead of that entry's length;
      * an oversized entry emitted an empty batch before itself.

    :param text_list: entries whose first element is the sentence text
    :param max_batch_chars: maximum characters per batch (default 500, the
        previous hard-coded value)
    :return: list of batches (lists of the original entries, in order);
        empty input yields ``[[]]`` as before
    """
    batches = []
    current = []
    current_len = 0
    for entry in text_list:
        entry_len = len(entry[0])
        if current and current_len + entry_len > max_batch_chars:
            batches.append(current)
            current = []
            current_len = 0
        current.append(entry)
        current_len += entry_len
    batches.append(current)
    return batches
|
|
|
|
|
|
def batch_predict(batch_data_list):
    """Run "prediction" on one batch of ``[text, *meta]`` entries.

    The model call is disabled, so every text passes through unchanged.

    :param batch_data_list: entries whose first element is the text
    :return: new list of ``[text] + meta`` rows in input order
    """
    texts = [row[0] for row in batch_data_list]
    metas = [row[1:] for row in batch_data_list]
    # predicted = autotitle.generate_beam_search_batch(texts)
    predicted = texts
    return [[text] + meta for text, meta in zip(predicted, metas)]
|
|
|
|
|
|
def is_chinese(char):
    """Return True if *char* is in the CJK Unified Ideographs range U+4E00..U+9FFF."""
    return '\u4e00' <= char <= '\u9fff'
|
|
|
|
|
|
def predict_data_post_processing(text_list):
    """Reassemble sentence entries into paragraphs.

    :param text_list: rows ``[text, sentence_id, short_num]`` in order; a
        non-zero ``short_num`` marks a continuation chunk that is appended to
        the previous sentence
    :return: ``{sentence_id: paragraph text}`` in first-appearance order

    Empty input returns ``{}``, and a continuation chunk with no preceding
    sentence starts a new one; both previously raised ``IndexError``. Pieces
    of the same id are concatenated even if not adjacent, instead of the later
    run silently overwriting the earlier one.
    """
    sentences = []  # [text, sentence_id] with continuation chunks merged in
    for row in text_list:
        if row[2] != 0 and sentences:
            sentences[-1][0] += row[0]
        else:
            sentences.append([row[0], row[1]])

    pieces = {}
    for text, sentence_id in sentences:
        pieces.setdefault(sentence_id, []).append(text)
    return {sentence_id: "".join(parts) for sentence_id, parts in pieces.items()}
|
|
|
|
|
|
# def main(text:list):
|
|
# # text_list = paragraph_test(text)
|
|
# # batch_data = batch_data_process(text_list)
|
|
# # text_list = []
|
|
# # for i in batch_data:
|
|
# # text_list.extend(i)
|
|
# # return_list = predict_data_post_processing(text_list)
|
|
# # return return_list
|
|
def post_sentence_ulit(text_info):
    """Extract the final sentence for one row after the model call.

    :param text_info: row whose index 0 is the original text, index 3 the
        ``if_change`` flag and last element the model reply
    :return: ``text_info[:4] + [sentence]``. *sentence* is the original text
        when the row was not meant to change. Otherwise it is the stripped
        reply, cut after the first "改写后:" label if present.
    """
    marker = "改写后:"
    if text_info[3] != True:
        return text_info[:4] + [text_info[0]]

    reply = text_info[-1].strip()
    if marker in reply:
        reply = reply[reply.index(marker) + len(marker):]
    return text_info[:4] + [reply.strip("\n")]
|
|
|
|
|
|
def has_chinese(s):
    """Return True if *s* contains at least one character in U+4E00..U+9FA5."""
    match = re.search('[\u4e00-\u9fa5]', s)
    return match is not None
|
|
|
|
|
|
def english_ulit(sentence):
    """Build the rewrite prompt for an English sentence.

    :param sentence: text to rewrite (converted with ``str`` and stripped)
    :return: ``(prompt, if_change)``. A text ending in "." gets the
        full-sentence prompt and any other non-empty text the half-sentence
        prompt, both with ``if_change=True``. Empty text gets a "remain
        unchanged" prompt with ``if_change=False``.
    """
    sentence = str(sentence).strip()
    if sentence == "":
        return f"The following words should remain unchanged:\n{sentence}\n", False

    if sentence[-1] == ".":
        text = f"Rewrite the following sentence, ensuring the meaning remains similar but with significant modifications. The word count should be close to the original, slightly more is acceptable but not less. Directly return the rewritten sentence without any additional explanation. Do not describe the changes made, do not use parentheses to indicate modifications, and do not provide any extra commentary. Simply return the revised sentence after thinking about it.\n{sentence}\n"
    else:
        # Fixed typo: this prompt used to start with "ewrite".
        text = f"Rewrite the following half-sentence, ensuring the meaning remains similar but with significant modifications. The word count should be close to the original, slightly more is acceptable but not less. The words at the beginning and end must connect properly with the surrounding text. Directly return the rewritten sentence without any additional explanation. Do not describe the changes made, do not use parentheses to indicate modifications, and do not provide any extra commentary. Simply return the revised sentence after thinking about it.\n{sentence}\n"
    return text, True
|
|
|
|
|
|
def chinese_ulit(sentence):
    """Build the rewrite prompt for a Chinese sentence.

    :param sentence: text to rewrite (converted with ``str`` and stripped)
    :return: ``(prompt, if_change)``:
        * 9 characters or fewer -> "leave unchanged" prompt, ``False``;
        * numbered heading (``pantten_biaoti_0``..``3``) shorter than 25
          characters -> rewrite prompt, ``False``;
        * otherwise containing "摘要"/"致谢" (``pantten_biaoti_4``/``5``)
          -> rewrite prompt, ``False``;
        * anything else -> rewrite prompt, ``True``.
        Text ending in "。" gets the full-sentence prompt, other text the
        half-sentence prompt.
    """
    max_length = 25
    sentence = str(sentence).strip()

    if len(sentence) <= 9:
        return f"下面词不做任何变化:\n{sentence}\n", False

    if sentence[-1] == "。":
        text = f"改写下面这句话,要求意思接近但是改动幅度比较大,字数要求与原来的相近,可以稍微多一点但是不能少,直接返回改写好的句子即可,不需要说其他的内容,不需要太多的思考过程,强制要求在思考结束后只返回修改后的句子,不需要解释句子是怎么改的,也不需要括弧修改细节,只要返回改好之后的句子 :\n{sentence}\n"
    else:
        text = f"改写下面半这句话,要求意思接近但是改动幅度比较大,字数要求与原来的相近,可以稍微多一点但是不能少,短句前后词跟上下句衔接不能有错误,直接返回改写好的句子即可,不需要说其他的内容,不需要太多的思考过程,强制要求在思考结束后只返回修改后的句子,不需要解释句子是怎么改的,也不需要括弧修改细节,只要返回改好之后的句子 :\n{sentence}\n"

    heading_patterns = (pantten_biaoti_0, pantten_biaoti_1, pantten_biaoti_2, pantten_biaoti_3)
    if any(re.search(p, sentence) for p in heading_patterns):
        # Short numbered lines are treated as headings and kept verbatim.
        return text, len(sentence) >= max_length
    if re.search(pantten_biaoti_4, sentence) or re.search(pantten_biaoti_5, sentence):
        return text, False
    return text, True
|
|
|
|
|
|
def pre_sentence_ulit(sentence):
    """Build the rewrite prompt for one sentence.

    Text containing Chinese characters goes through ``chinese_ulit``, all
    other text through ``english_ulit``.

    :param sentence: sentence text
    :return: ``(prompt, if_change)``
    """
    builder = chinese_ulit if has_chinese(sentence) else english_ulit
    text, if_change = builder(sentence)
    return text, if_change
|
|
|
|
|
|
def main(texts: dict):
    """Rewrite every paragraph of *texts*.

    Pipeline: split into sentences (``paragraph_test``), build prompts
    (``pre_sentence_ulit``), query the model (``get_multiple_urls``), clean
    the replies (``post_sentence_ulit``) and stitch paragraphs back together
    (``predict_data_post_processing``).

    :param texts: mapping of paragraph id -> paragraph text
    :return: mapping of paragraph id -> rewritten paragraph
    """
    sentences = paragraph_test(texts)

    text_info = []
    for entry in sentences:
        print("sen", entry[0])
        prompt, if_change = pre_sentence_ulit(entry[0])
        text_info.append([entry[0], entry[1], entry[2], if_change, prompt])

    text_info = get_multiple_urls(text_info)
    text_info = [post_sentence_ulit(row) for row in text_info]

    # [final sentence, sentence_id, short_num]
    text_list_new = [[row[-1]] + row[1:3] for row in text_info]
    return predict_data_post_processing(text_list_new)
|
|
|
|
|
|
def remove_specific_parentheses_1(text):
    """Drop parenthesised groups that contain model annotations.

    Each top-level "(" ... matching ")" group (nesting respected; an
    unbalanced group runs to the end of *text*) is removed if it contains
    any of the annotation markers below, and kept verbatim otherwise.

    :param text: model output
    :return: *text* with annotated groups removed
    """
    markers = ('说明:', '注:', '解析:', '修改说明:')
    pieces = []
    pos = 0
    length = len(text)
    while pos < length:
        if text[pos] != '(':
            pieces.append(text[pos])
            pos += 1
            continue
        start = pos
        depth = 0
        while pos < length:
            ch = text[pos]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            pos += 1
            if depth == 0:
                break
        group = text[start:pos]
        if not any(marker in group for marker in markers):
            pieces.append(group)
    return ''.join(pieces)
|
|
|
|
|
|
def remove_specific_parentheses_2(text):
    """Cut *text* at the first "解析改写要点:" label, discarding the rest.

    :param text: model output
    :return: text before the label, or *text* unchanged if it is absent
    """
    label_at = text.find('解析改写要点:')
    if label_at < 0:
        return text
    return text[:label_at]
|
|
|
|
|
|
def _run_rewrite(texts, text_type):
    """Rewrite *texts* if *text_type* is supported.

    :return: ``(ok, texts_list)`` - ``(True, main(texts))`` on success;
        ``(False, [])`` when the type is unsupported or ``main`` raised.
        The previous ``assert`` killed the worker thread instead.
    """
    if text_type not in ('focus', 'chapter'):
        print("unsupported text_type:", text_type)
        return False, []
    try:
        return True, main(texts)
    except Exception:
        traceback.print_exc()
        return False, []


def _process_query(raw_query):
    """Process one queued job and publish its result.

    :param raw_query: Redis list element, JSON ``{"id": ..., "path": ...}``
        as bytes or str; *path* points at the stored request file
    Side effects: writes ``./new_data_logs/<id>.json``, stores that path
    under the job id in Redis (expiry 86400), removes the id from
    ``db_key_querying``, sends an alert mail on failure and appends to the
    access log.
    """
    if isinstance(raw_query, bytes):
        raw_query = raw_query.decode('UTF-8')
    data_dict_path = json.loads(raw_query)
    with open(data_dict_path['path'], encoding='utf8') as f1:
        data_dict = json.load(f1)
    print(time.localtime(time.time()))
    query_id = data_dict['id']
    texts = data_dict["text"]
    text_type = data_dict["text_type"]

    ok, texts_list = _run_rewrite(texts, text_type)
    if ok:
        # Strip bracketed annotations and trailing explanation sections.
        cleaned = {
            key: remove_specific_parentheses_2(remove_specific_parentheses_1(value))
            for key, value in texts_list.items()
        }
        return_text = {"texts": cleaned, "probabilities": None, "status_code": 200}
    else:
        smtp_f("drop_weight_rewrite_increase")
        return_text = {"texts": texts_list, "probabilities": None, "status_code": 400}

    load_result_path = "./new_data_logs/{}.json".format(query_id)
    print("query_id: ", query_id)
    print("load_result_path: ", load_result_path)
    with open(load_result_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps Chinese text readable in the result file.
        json.dump(return_text, f2, ensure_ascii=False, indent=4)
    debug_id_1 = 1
    redis_.set(query_id, load_result_path, 86400)
    debug_id_2 = 2
    redis_.srem(db_key_querying, query_id)
    debug_id_3 = 3
    log.log('start at',
            'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format(
                query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3))


def classify():
    """Background worker: consume jobs from the Redis list ``db_key_query`` forever.

    Polls every 3 s while the list is empty. A job that fails is logged and
    skipped so one bad request cannot stop the only worker thread.
    """
    while True:
        if redis_.llen(db_key_query) == 0:
            time.sleep(3)
            continue
        query = redis_.lpop(db_key_query)
        if query is None:
            # The list was emptied between llen() and lpop().
            continue
        try:
            _process_query(query)
        except Exception:
            logging.exception("failed to process query %r", query)
            traceback.print_exc()
|
|
|
|
|
|
@app.route("/predict", methods=["POST"])
def handle_query():
    """Accept a rewrite request and enqueue it for the background worker.

    Expects a JSON object ``{"texts": {<paragraph id>: <text>}, "text_type": ...}``.
    The request is saved to ``./request_data_logs/<id>.json``, ``{"id",
    "path"}`` is pushed onto ``db_key_query`` and the id is added to
    ``db_key_querying``.

    Response body ``status_code`` (the HTTP status itself is 200):
      * 200 - ``{"texts": {"id": <job id>}}``
      * 401 - ``texts`` is not a dict
      * 402 - ``texts`` is missing or null, or the body is not a JSON object
    """
    print(request.remote_addr)
    # silent=True: a missing/invalid JSON body gives None instead of raising,
    # so it is reported through the 402 branch rather than as an HTTP error.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    texts = payload.get("texts")
    text_type = payload.get("text_type")
    if texts is None:
        return jsonify({"texts": "输入了空值", "probabilities": None, "status_code": 402})
    if not isinstance(texts, dict):
        return jsonify({"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401})

    id_ = str(uuid.uuid1())  # unique id for this query
    print("uuid: ", id_)
    d = {'id': id_, 'text': texts, "text_type": text_type}
    load_request_path = './request_data_logs/{}.json'.format(id_)
    with open(load_request_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps Chinese text readable in the request file.
        json.dump(d, f2, ensure_ascii=False, indent=4)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))
    redis_.sadd(db_key_querying, id_)
    print("ok")
    return jsonify({"texts": {'id': id_, }, "probabilities": None, "status_code": 200})
|
|
|
|
|
|
# @app.route("/predict", methods=["POST"])
|
|
# def handle_query():
|
|
# print(request.remote_addr)
|
|
# texts = request.json["texts"]
|
|
# text_type = request.json["text_type"]
|
|
# if isinstance(texts, dict):
|
|
# id_ = str(uuid.uuid1()) # 为query生成唯一标识
|
|
# return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
|
|
# print("ok")
|
|
# else:
|
|
# return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
|
|
# return jsonify(return_text) # 返回结果
|
|
|
|
|
|
# Start the Redis queue consumer when the module is imported (this also runs
# when a server imports the module without executing __main__).
# daemon=True lets the interpreter exit once the web server stops; a
# non-daemon thread running classify()'s endless loop blocked shutdown.
t = Thread(target=classify, daemon=True)
t.start()
|
|
|
|
if __name__ == "__main__":
    # Logging is already configured by the module-level logging.basicConfig
    # call. A second basicConfig here was a no-op (it does nothing once the
    # root logger has handlers), so it is not repeated.
    app.run(host="0.0.0.0", port=14012, threaded=True, debug=False)
|
|
|