# This is a sample Python script.
# Press Shift+F10 to run it, or replace it with your own code.
# Double-press Shift to search everywhere for classes, files, tool windows, actions and settings.

import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import requests
import time
from flask import Flask, jsonify, request


app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

# Embedding model used for both indexing and querying (bge-large-zh-v1.5, 1024-dimensional vectors).
model = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')

# Main prompt template, kept in Chinese because it is sent to a Chinese-language LLM.
# Roughly: "I am a TCM practitioner and you are a TCM doctor's assistant. My patient has the
# following symptoms: {symptoms}. Based on these symptoms I looked up the material: {references}.
# Please use the material and formulas above to prescribe the correct formula and treatment plan."
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
{}
根据这些症状,我通过查找资料,{}
请根据上面的这些资料和方子,根据患者的症状帮我开出正确的药方和治疗方案'''

# Per-source reference template, roughly: "In the '{source}' material, the following related content was found: {content}".
propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
{}'''


def dialog_line_parse(url, text):
    """
    Send the payload to the remote model service and return the parsed result.

    :param url: model service url
    :param text: request payload for the model
    :return: JSON returned by the model, or {} on failure
    """
    response = requests.post(
        url,
        json=text,
        timeout=100000
    )
    if response.status_code == 200:
        return response.json()
    else:
        # logger.error(
        #     "【{}】 Failed to get a proper response from remote "
        #     "server. Status Code: {}. Response: {}"
        #     "".format(url, response.status_code, response.text)
        # )
        print("【{}】 Failed to get a proper response from remote "
              "server. Status Code: {}. Response: {}"
              "".format(url, response.status_code, response.text))
        return {}
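

# Note on the remote model service: judging from how the responses are consumed in
# main() below, the assumed response shapes are
#   /predict -> {"texts": {"id": "<task id>"}}
#   /search  -> {"code": 200, "text": "<reasoning></think><answer>"}
# These shapes are inferred from this file, not from the remote service's documentation.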


def shengcehng_array(data):
    """Encode a list of sentences into normalized embedding vectors."""
    embs = model.encode(data, normalize_embeddings=True)
    return embs


def Building_vector_database(type, name, data):
    """Embed every sentence in `data` and save the matrix to data_np/<name>.npy (`type` is currently unused)."""
    data_ndarray = np.empty((0, 1024))  # bge-large-zh-v1.5 produces 1024-dim vectors
    for sen in data:
        data_ndarray = np.concatenate((data_ndarray, shengcehng_array([sen])))
        print("data_ndarray.shape", data_ndarray.shape)

    print("data_ndarray.shape", data_ndarray.shape)
    np.save(f'data_np/{name}.npy', data_ndarray)
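

# Alternative sketch (not wired into the routes above): SentenceTransformer.encode
# accepts a list of sentences directly, so the whole corpus can be embedded in one
# call instead of concatenating row by row. Kept separate so the original flow is unchanged.
def building_vector_database_batched(name, data):
    embs = model.encode(data, normalize_embeddings=True)  # shape: (len(data), 1024)
    np.save(f'data_np/{name}.npy', embs)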


def ulit_request_file(file, title):
    file_name = file.filename
    file_name_save = "data_file/{}.txt".format(title)
    file.save(file_name_save)

    # Try GBK first, then fall back to UTF-8.
    try:
        with open(file_name_save, encoding="gbk") as f:
            content = f.read()
    except UnicodeDecodeError:
        with open(file_name_save, encoding="utf-8") as f:
            content = f.read()
    # elif file_name.split(".")[-1] == "docx":
    #     content = docx2txt.process(file_name_save)

    content_list = [i for i in content.split("\n")]

    return content_list


def main(question, db_type, top):
    db_dict = {
        "1": "yetianshi"
    }
    # Define file paths

    # Load the files

    # Split the text

    # Build the vector database
    # 1. Direct matching
    # 2. Use the large model to generate a question from the text first, then match

    # Match context against the question
    d = 1024  # embedding dimension of bge-large-zh-v1.5
    db_type_list = db_type.split(",")

    paper_list_str = ""
    for db_name in db_type_list:
        embs = shengcehng_array([question])
        index = faiss.IndexFlatL2(d)  # build the index
        data_np = np.load(f"data_np/{db_name}.npy")
        data_str = open(f"data_file/{db_name}.txt").read().split("\n")
        index.add(data_np)
        D, I = index.search(embs, int(top))
        print(I)

        reference_list = []
        for idx in I[0]:
            reference_list.append(data_str[idx])

        for num, ref in enumerate(reference_list):
            paper_list_str += "第{}篇\n{}\n".format(str(num + 1), ref)  # "第N篇" = "Passage N"

    # Build the prompt
    print("paper_list_str", paper_list_str)
    propmt_connect_ziliao_input = []
    for db_name in db_type_list:
        propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(db_name, paper_list_str))

    propmt_connect_ziliao_input_str = ",".join(propmt_connect_ziliao_input)
    propmt_connect_input = propmt_connect.format(question, propmt_connect_ziliao_input_str)

    # Generate the answer
    url_predict = "http://192.168.31.74:26000/predict"
    url_search = "http://192.168.31.74:26000/search"

    # data = {
    #     "content": propmt_connect_input
    # }
    data = {
        "content": propmt_connect_input,
        "model": "qwq-32",
        "top_p": 0.9,
        "temperature": 0.6
    }

    res = dialog_line_parse(url_predict, data)
    id_ = res["texts"]["id"]

    data = {
        "id": id_
    }

    # Poll the search endpoint until the generation task is finished.
    while True:
        res = dialog_line_parse(url_search, data)
        if res["code"] == 200:
            break
        else:
            time.sleep(1)

    split_str = "</think>"
    think, response = str(res["text"]).split(split_str)
    return think, response


@app.route("/upload_file", methods=["POST"])
|
|
def upload_file():
|
|
print(request.remote_addr)
|
|
file = request.files.get('file')
|
|
title = request.form.get("title")
|
|
data = ulit_request_file(file, title)
|
|
Building_vector_database("1", title, data)
|
|
return_json = {
|
|
"code": 200,
|
|
"info": "上传完成"
|
|
}
|
|
return jsonify(return_json) # 返回结果
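

# Usage sketch (not called anywhere in this module): an example client for the
# /upload_file endpoint above. The host/port assume the server is run locally with the
# settings at the bottom of this file; the file path and title are hypothetical.
def example_upload_client():
    with open("example_corpus.txt", "rb") as f:  # hypothetical plain-text knowledge file
        resp = requests.post(
            "http://127.0.0.1:27000/upload_file",       # assumed local deployment
            files={"file": ("example_corpus.txt", f)},  # read server-side via request.files
            data={"title": "example_corpus"},           # saved as data_file/example_corpus.txt and data_np/example_corpus.npy
        )
    print(resp.json())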
@app.route("/search", methods=["POST"])
|
|
def search():
|
|
print(request.remote_addr)
|
|
texts = request.json["texts"]
|
|
text_type = request.json["text_type"]
|
|
top = request.json["top"]
|
|
think, response = main(texts, text_type, top)
|
|
return_json = {
|
|
"code": 200,
|
|
"think": think,
|
|
"response": response
|
|
}
|
|
return jsonify(return_json) # 返回结果
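

# Usage sketch (not called anywhere in this module): an example client for the
# /search endpoint above. "text_type" must match the title of a previously uploaded
# corpus; the question text, title and host/port here are assumptions for illustration.
def example_search_client():
    resp = requests.post(
        "http://127.0.0.1:27000/search",  # assumed local deployment
        json={
            "texts": "患者咳嗽咽痛,怎么处理?",  # example patient symptoms (hypothetical)
            "text_type": "example_corpus",       # corpus title used for the data_np/<title>.npy lookup
            "top": 3,                            # number of passages to retrieve per corpus
        },
    )
    result = resp.json()
    print(result["think"])
    print(result["response"])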


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=27000, threaded=True, debug=False)