|
|
|
# coding:utf-8
|
|
|
|
import os
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
# Pin GPU enumeration to PCI bus order and restrict this process to GPU 0.
# Must be set before `import torch` below for the restriction to take effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
|
import torch
|
|
|
|
from transformers import (
|
|
|
|
AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding,
|
|
|
|
Trainer, TrainingArguments
|
|
|
|
)
|
|
|
|
from flask import Flask, jsonify
|
|
|
|
from flask import request
|
|
|
|
import uuid
|
|
|
|
app = Flask(__name__)
# Allow non-ASCII (Chinese) characters to pass through jsonify un-escaped.
app.config["JSON_AS_ASCII"] = False
|
|
|
|
from threading import Thread
|
|
|
|
import redis
|
|
|
|
import uuid
|
|
|
|
import time
|
|
|
|
import json
|
|
|
|
import docx2txt
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
|
|
# Shared Redis connection pool for the request queue.
# NOTE(review): port 63179 is unusual (Redis default is 6379) — confirm this
# is intentional and not a typo.
# NOTE(review): hardcoded password in source — should move to env/config.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=12, password="zhicheng123*")
# NOTE(review): `decode_responses` is normally a ConnectionPool option; passing
# it to Redis() alongside an explicit pool may have no effect — verify.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

# Redis list/key names used by the producer (this service) and the worker.
db_key_query = 'query'
db_key_querying = 'querying'
db_key_queryset = 'queryset'
batch_size = 32
# Matches quoted dialog: "..." or '...' or Chinese full-width quotes “...”.
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
|
|
|
|
|
|
|
|
def get_dialogs_index(line: str):
    """Locate quoted dialog spans in ``line``.

    :param line: text to scan for quoted dialog (see ``RE_DIALOG``)
    :return: tuple of
        ``dialogs_text`` — matched dialog substrings, quotes included;
        ``dialogs_index`` — every character index covered by a dialog match;
        ``other_index`` — the remaining character indices of ``line``.
    """
    dialogs_text = []
    dialogs_index = []
    # Single regex pass (the original scanned twice: finditer + findall).
    for match in RE_DIALOG.finditer(line):
        dialogs_text.append(match.group())
        dialogs_index.extend(range(match.start(), match.end()))
    # Set membership keeps the complement O(n) instead of the original
    # O(n * m) `i not in <list>` scan.
    covered = set(dialogs_index)
    other_index = [i for i in range(len(line)) if i not in covered]
    return dialogs_text, dialogs_index, other_index
|
|
|
|
|
|
|
|
|
|
|
|
def chulichangju_1(text, chulipangban_return_list):
    """Recursively split ``text`` into chunks of at most 500 characters.

    Each cut is made after the last '。' found within the first 500
    characters, skipping periods that fall inside quoted dialog
    (per ``get_dialogs_index``).  The remainder is re-queued and the
    function recurses on it.

    :param text: text to split
    :param chulipangban_return_list: accumulator list the chunks are
        appended to (mutated in place and also returned)
    :return: the accumulator with the chunks of ``text`` appended
    """
    # Sentence-ending punctuation we are allowed to cut at.
    fuhao = ["。"]
    # NOTE(review): dialog indices are computed over the whole `text`, but
    # compared against positions in the first-500-char window — positions
    # coincide only for the first window; confirm intended.
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:500]
    text_2 = text[500:]
    text_1_new = ""
    # Short enough already: emit as a single chunk and stop.
    if text_2 == "":
        chulipangban_return_list.append(text_1)
        return chulipangban_return_list
    # Scan the 500-char window backwards for the last usable '。'.
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            # Don't cut inside a quoted dialog span.
            if i in dialogs_index:
                continue
            # Chunk is everything up to and including the period.
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append(text_1_new)
            if text_2 != "":
                # Unless the cut landed exactly at the window edge,
                # prepend the tail of the window to the remainder.
                if i + 1 != 500:
                    text_2 = text_1[i + 1:] + text_2
            break
        # else:
        #     chulipangban_return_list.append(text_1)
    # No usable '。' in the window: emit the raw 500-char slice as-is.
    if text_1_new == "":
        chulipangban_return_list.append(text_1)
    # Recurse on whatever text remains past the window.
    if text_2 != "":
        chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list)
    return chulipangban_return_list
|
|
|
|
|
|
|
|
def ulit_request_file(file):
    """Persist an uploaded file and split its text into <500-char segments.

    :param file: werkzeug ``FileStorage`` from ``request.files``
    :return: list of line segments; lines of 500+ characters are further
        split by ``chulichangju_1``
    :raises UnicodeDecodeError: if the file is neither GBK nor UTF-8
    :raises NameError: currently, for any extension other than ``.txt``
        (``content`` is never bound) — the docx branch is disabled.
    """
    file_name = file.filename
    file_name_save = "data/request/{}".format(file_name)
    file.save(file_name_save)

    if file_name.split(".")[-1] == "txt":
        # Client files are often GBK-encoded; try that first, fall back
        # to UTF-8.  Narrowed from a bare `except:` so that unrelated
        # errors (e.g. OSError) are no longer silently re-routed.
        try:
            with open(file_name_save, encoding="gbk") as f:
                content = f.read()
        except UnicodeDecodeError:
            with open(file_name_save, encoding="utf-8") as f:
                content = f.read()
    # elif file_name.split(".")[-1] == "docx":
    #     content = docx2txt.process(file_name_save)

    content_list = [i for i in content.split("\n")]
    print(content_list)

    content_list_new = []
    for sen in content_list:
        if len(sen) < 500:
            content_list_new.append(sen)
        else:
            content_list_new.extend(chulichangju_1(sen, []))

    # BUG FIX: the original returned the unsplit `content_list`, which
    # discarded the 500-char segmentation computed above (dead code).
    return content_list_new
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/predict", methods=["POST"])
def handle_query_predict():
    """Accept an uploaded document, enqueue it for processing via Redis.

    Form fields: ``title``, ``author``; file field: ``file``.
    The request payload is written to ``./request_data_logs/<uuid>.json``
    and its path is pushed onto the ``db_key_query`` Redis list for a
    worker to consume.

    :return: JSON envelope with ``orderId`` (the generated UUID).
    """
    print(request.remote_addr)

    # Fields not yet supplied by the client; kept for payload schema parity.
    dataBases = ""
    minSimilarity = ""  # txt
    minWords = ""
    title = request.form.get("title")
    author = request.form.get("author")  # txt
    file = request.files.get('file')
    token = ""
    account = ""
    goodsId = ""
    callbackUrl = ""
    content_list = ulit_request_file(file)

    # Unique id binding this request to its queued payload.
    id_ = str(uuid.uuid1())
    id_ = id_.upper()
    print("uuid: ", id_)
    d = {
        'id': id_,
        'dataBases': dataBases,
        'minSimilarity': minSimilarity,
        'minWords': minWords,
        'title': title,
        'author': author,
        'content_list': content_list,
        'token': token,
        'account': account,
        'goodsId': goodsId,
        'callbackUrl': callbackUrl
    }
    print(d)

    # Persist the payload; the worker reads it back by path from Redis.
    # FIX: create the log directory up front — previously a fresh deploy
    # crashed here with FileNotFoundError because it never existed.
    os.makedirs('./request_data_logs', exist_ok=True)
    load_request_path = './request_data_logs/{}.json'.format(id_)
    # ensure_ascii=False keeps Chinese text readable; indent=4 for humans.
    with open(load_request_path, 'w', encoding='utf8') as f2:
        json.dump(d, f2, ensure_ascii=False, indent=4)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))  # enqueue

    return_text = {
        'code': 0,
        'msg': "请求成功",
        'data': {
            'balances': "",
            'orderId': id_,
            'consumeNum': ""
        }
    }

    return jsonify(return_text)  # respond immediately; processing is async
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # threaded=True lets Flask's dev server handle concurrent uploads;
    # for production a WSGI server (gunicorn/uwsgi) would be preferable.
    app.run(host="0.0.0.0", port=16005, threaded=True, debug=False)
|