RAG knowledge base Q&A

# This is a sample Python script.
# Press Shift+F10 to run it, or replace it with your own code.
# Double-press Shift to search everywhere for classes, files, tool windows, actions, and settings.
import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import requests
import time
from flask import Flask, jsonify
from flask import request
import pandas as pd
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
model = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
{}
根据这些症状我通过查找资料{}
请根据上面的这些资料和方子并根据每篇文章的转发数确定文章的重要程度转发数越高的文章最终答案的参考度越高反之越低根据患者的症状和上面的文章的资料的重要程度以及文章和症状的匹配程度帮我开出正确的药方和治疗方案'''

propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
{}'''


def dialog_line_parse(url, text):
    """
    Send the payload to the model service and return the parsed result.
    :param url: model service URL
    :param text: payload sent to the model
    :return: JSON returned by the model, or {} on failure
    """
    response = requests.post(
        url,
        json=text,
        timeout=100000
    )
    if response.status_code == 200:
        return response.json()
    else:
        # logger.error(
        #     "【{}】 Failed to get a proper response from remote "
        #     "server. Status Code: {}. Response: {}"
        #     "".format(url, response.status_code, response.text)
        # )
        print("【{}】 Failed to get a proper response from remote "
              "server. Status Code: {}. Response: {}"
              "".format(url, response.status_code, response.text))
        return {}


def shengcehng_array(data):
    # normalize_embeddings=True returns unit vectors, so the inner-product search
    # (faiss.IndexFlatIP) used below is equivalent to cosine similarity.
    embs = model.encode(data, normalize_embeddings=True)
    return embs


def Building_vector_database(type, name, df):
    # Encode every row and stack the vectors into an (n, 1024) array,
    # then persist it so /search can load it later.
    data_ndarray = np.empty((0, 1024))
    for sen in df:
        data_ndarray = np.concatenate((data_ndarray, shengcehng_array([sen[0]])))
    print("data_ndarray.shape", data_ndarray.shape)
    np.save(f'data_np/{name}.npy', data_ndarray)


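# The loop above encodes and re-allocates one row at a time. Below is a minimal
# batched sketch; the function name is hypothetical and not part of the original
# code. SentenceTransformer.encode accepts a list of sentences directly, which
# avoids the per-row np.concatenate.
def Building_vector_database_batched(name, df):
    sentences = [row[0] for row in df]
    embs = model.encode(sentences, normalize_embeddings=True)  # shape: (len(df), 1024)
    np.save(f'data_np/{name}.npy', embs)

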
def ulit_request_file(file, title):
    file_name = file.filename
    file_name_save = "data_file/{}.csv".format(title)
    file.save(file_name_save)

    # try:
    #     with open(file_name_save, encoding="gbk") as f:
    #         content = f.read()
    # except:
    #     with open(file_name_save, encoding="utf-8") as f:
    #         content = f.read()

    # elif file_name.split(".")[-1] == "docx":
    #     content = docx2txt.process(file_name_save)

    # content_list = [i for i in content.split("\n")]
    df = pd.read_csv(file_name_save, sep="\t", encoding="utf-8").values.tolist()

    return df


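# Expected upload format (inferred from how rows are used in main() below, not a
# documented schema): a tab-separated file whose first line is a header row
# (consumed by pd.read_csv) and whose data rows carry at least four columns:
#   text \t forward count \t comment count \t like count
# e.g.
#   感冒发热的处理……\t120\t30\t500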
def main(question, db_type, top):
    db_dict = {
        "1": "yetianshi"
    }
    '''
    Define file paths
    '''
    '''
    Load the files
    '''
    '''
    Split the text
    '''
    '''
    Build the vector database
    1. Plain matching
    2. Use the LLM to generate a question for each passage first, then match
       (not implemented here; see the sketch after this function)
    '''
    '''
    Match context against the question
    '''
    d = 1024
    db_type_list = db_type.split(",")

    paper_list_str = ""
    for i in db_type_list:
        embs = shengcehng_array([question])
        index = faiss.IndexFlatIP(d)  # build the index (inner product == cosine, vectors are normalized)
        data_np = np.load(f"data_np/{i}.npy")
        # data_str = open(f"data_file/{i}.txt").read().split("\n")
        data_str = pd.read_csv(f"data_file/{i}.csv", sep="\t", encoding="utf-8").values.tolist()
        index.add(data_np)
        D, I = index.search(embs, int(top))
        print(I)

        reference_list = []
        for idx, score in zip(I[0], D[0]):
            reference_list.append([data_str[idx], score])

        for rank, j in enumerate(reference_list):
            paper_list_str += "{}\n{},此篇文章的转发数为{},评论数为{},点赞数为{}\n,此篇文章跟问题的相关度为{}%\n".format(
                str(rank + 1), j[0][0], j[0][1], j[0][2], j[0][3], j[1])

    '''
    Build the prompt
    '''
    print("paper_list_str", paper_list_str)
    propmt_connect_ziliao_input = []
    for i in db_type_list:
        propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(i, paper_list_str))

    propmt_connect_ziliao_input_str = "".join(propmt_connect_ziliao_input)
    propmt_connect_input = propmt_connect.format(question, propmt_connect_ziliao_input_str)

    '''
    Generate the answer
    '''
    url_predict = "http://192.168.31.74:26000/predict"
    url_search = "http://192.168.31.74:26000/search"

    # data = {
    #     "content": propmt_connect_input
    # }
    data = {
        "content": propmt_connect_input,
        "model": "qwq-32",
        "top_p": 0.9,
        "temperature": 0.6
    }
    res = dialog_line_parse(url_predict, data)
    id_ = res["texts"]["id"]

    # Poll the /search endpoint of the model service until the generation job is done.
    data = {
        "id": id_
    }
    while True:
        res = dialog_line_parse(url_search, data)
        if res["code"] == 200:
            break
        else:
            time.sleep(1)

    # The model returns "<think> ... </think> answer"; split it into the two parts.
    split_str = "</think>"
    think, response = str(res["text"]).split(split_str)
    return think, response


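# Strategy 2 from the comments in main() (generate a question for each passage with
# the LLM, embed the generated question, and index that instead of the raw text) is
# not implemented above. The sketch below is hypothetical: it reuses the /predict and
# /search endpoints the same way main() does, while the function name, the "_question"
# file suffix, and the question-generation prompt are assumptions, not original code.
def build_question_index(name, df,
                         url_predict="http://192.168.31.74:26000/predict",
                         url_search="http://192.168.31.74:26000/search"):
    questions = []
    for row in df:
        data = {
            "content": "请根据下面的内容生成一个患者可能会问的问题:\n{}".format(row[0]),
            "model": "qwq-32",
            "top_p": 0.9,
            "temperature": 0.6
        }
        res = dialog_line_parse(url_predict, data)
        id_ = res["texts"]["id"]
        while True:
            res = dialog_line_parse(url_search, {"id": id_})
            if res["code"] == 200:
                break
            time.sleep(1)
        # keep only the part after the reasoning block, mirroring main()
        questions.append(str(res["text"]).split("</think>")[-1])
    embs = shengcehng_array(questions)
    np.save(f"data_np/{name}_question.npy", embs)

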
@app.route("/upload_file", methods=["POST"])
def upload_file():
print(request.remote_addr)
file = request.files.get('file')
title = request.form.get("title")
2 months ago
df = ulit_request_file(file, title)
Building_vector_database("1", title, df)
return_json = {
"code": 200,
"info": "上传完成"
}
return jsonify(return_json) # 返回结果
@app.route("/upload_file_check", methods=["POST"])
def upload_file_check():
print(request.remote_addr)
file = request.files.get('file')
title = request.form.get("title")
df = ulit_request_file(file, title)
Building_vector_database("1", title, df)
2 months ago
return_json = {
"code": 200,
"info": "上传完成"
}
return jsonify(return_json) # 返回结果
@app.route("/search", methods=["POST"])
def search():
print(request.remote_addr)
texts = request.json["texts"]
text_type = request.json["text_type"]
top = request.json["top"]
think, response = main(texts, text_type, top)
return_json = {
"code": 200,
"think": think,
"response": response
}
return jsonify(return_json) # 返回结果
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=27000, threaded=True, debug=False)
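
# Example requests (field names taken from the routes above; the host/port and the
# "yetianshi" title are illustrative, and the title used at upload time must match
# the text_type passed to /search):
#
#   curl -X POST http://127.0.0.1:27000/upload_file \
#        -F "file=@data.csv" -F "title=yetianshi"
#
#   curl -X POST http://127.0.0.1:27000/search \
#        -H "Content-Type: application/json" \
#        -d '{"texts": "患者咳嗽发热三天", "text_type": "yetianshi", "top": 5}'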