# rag知识库问答 (RAG knowledge-base Q&A)
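# WebSocket service: embeds an incoming question, retrieves the closest
# passages from per-title CSV knowledge bases (data_file_res/<title>.csv)
# with faiss, and streams an answer from an OpenAI-compatible chat model.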
import os
import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import requests
import time
from flask import Flask, jsonify, Response, request
from flask_cors import CORS
import pandas as pd
import redis
from openai import OpenAI
import asyncio
import websockets
import json
import ssl
import pathlib
# Flask application object with CORS enabled for all routes.
# JSON_AS_ASCII=False is meant to keep non-ASCII (e.g. Chinese) text unescaped
# in JSON responses. NOTE(review): Flask >= 2.3 ignores this key in favour of
# app.json.ensure_ascii — confirm the installed Flask version.
app = Flask(__name__)
CORS(app)
app.config.update(JSON_AS_ASCII=False)
# Redis connection pool. Every setting can be overridden from the environment
# so credentials need not be edited in source; the defaults are the original
# hard-coded values, so behaviour is unchanged when nothing is set.
pool = redis.ConnectionPool(
    host=os.environ.get("RAG_REDIS_HOST", "localhost"),
    port=int(os.environ.get("RAG_REDIS_PORT", "63179")),
    max_connections=100,
    db=int(os.environ.get("RAG_REDIS_DB", "1")),
    password=os.environ.get("RAG_REDIS_PASSWORD", "zhicheng123*"),
)
# NOTE(review): redis-py ignores ``decode_responses`` when an explicit
# ``connection_pool`` is supplied, so replies from ``redis_`` are bytes, not str.
# If str replies were intended, pass decode_responses=True to ConnectionPool.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# Presumably Redis key names for pending / in-progress queries — verify with callers.
db_key_query = 'query'
db_key_querying = 'querying'
batch_size = 32
# OpenAI-compatible chat endpoint. Overridable from the environment (service
# specific names, so a global OPENAI_API_KEY is not picked up by accident);
# the defaults are the original hard-coded values.
openai_api_key = os.environ.get("RAG_LLM_API_KEY", "token-abc123")
openai_api_base = os.environ.get("RAG_LLM_API_BASE", "http://127.0.0.1:12011/v1")
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
# Use the first model advertised by the server (a network call at import time).
models = client.models.list()
if not models.data:
    # Fail with an explicit message instead of a bare IndexError.
    raise RuntimeError(f"No models are served at {openai_api_base}")
model = models.data[0].id
# Sentence-embedding model used to encode questions before the faiss search.
model_encode = SentenceTransformer(
    os.environ.get(
        "RAG_EMBED_MODEL_PATH",
        "/home/majiahui/project/models-llm/bge-large-zh-v1.5",
    )
)
# Main prompt template, filled by main() with (question, reference sections).
# English gist: "I am a TCM practitioner and you are my assistant. My patient
# has these symptoms: {question}. Based on them I looked up material {sections}.
# Using that material, weigh each article by its repost count (higher count =
# more weight) and by how well it matches the symptoms, then give the correct
# prescription and treatment plan."
# NOTE(review): the reference list built in main() contains only the article
# text and a relevance score — no repost count — so the repost-count
# instruction has nothing to act on. Confirm whether the CSV should supply it.
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
{}
根据这些症状我通过查找资料{}
请根据上面的这些资料和方子并根据每篇文章的转发数确定文章的重要程度转发数越高的文章最终答案的参考度越高反之越低根据患者的症状和上面的文章的资料的重要程度以及文章和症状的匹配程度帮我开出正确的药方和治疗方案'''
# Per-source section template, filled with (knowledge-base title, reference
# list). English gist: "In the '{title}' material there is the following
# related content:\n{references}".
propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
{}'''
def dialog_line_parse(text, *, url_predict="http://118.178.228.101:12004/predict", timeout=100000):
    """Post ``text`` as a JSON body to the prediction service and return its reply.

    :param text: JSON-serialisable payload sent as the request body.
    :param url_predict: endpoint to POST to (defaults to the original hard-coded URL).
    :param timeout: ``requests`` timeout in seconds (default kept at the original 100000).
    :return: the decoded JSON body on HTTP 200, otherwise an empty dict.
    :raises requests.RequestException: on connection errors or timeouts (not caught here).
    """
    response = requests.post(url_predict, json=text, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    # Non-200: log and return an empty result instead of raising.
    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url_predict, response.status_code, response.text))
    return {}
def shengcehng_array(data):
    """Embed ``data`` with the module-level sentence-embedding model.

    ``normalize_embeddings=True`` makes every vector unit length, so inner
    products between these embeddings equal cosine similarities.
    """
    return model_encode.encode(data, normalize_embeddings=True)
def _load_valid_rows(title_dan):
    """Read ``data_file_res/<title_dan>.csv`` (tab-separated, UTF-8) and keep rows whose column 3 equals True.

    Within this module, column 1 of a row is quoted into the prompt and the
    last column holds the serialized embedding vector.
    """
    rows = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist()
    # ``== True`` (not plain truthiness) is kept deliberately so that values
    # such as the string "False" are not accepted as valid.
    return [row for row in rows if row[3] == True]  # noqa: E712


def _search_references(embs, rows, top):
    """Return up to ``top`` ``(row, score)`` pairs, best first, ranked by inner product with ``embs``."""
    if not rows:
        # faiss cannot add an empty array; nothing to retrieve for this source.
        return []
    # NOTE(security): eval() executes whatever text is stored in the CSV file.
    # It is kept because the vector serialization format is not visible here;
    # switch to ast.literal_eval once the stored format is confirmed to be a
    # plain list literal.
    vectors = np.asarray([eval(row[-1]) for row in rows], dtype="float32")
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    D, I = index.search(embs, int(top))
    print(I)
    # When the index holds fewer than ``top`` vectors, faiss pads the result
    # with id -1; rows[-1] would silently repeat the last row, so skip those.
    return [(rows[i], score) for i, score in zip(I[0], D[0]) if i >= 0]


def _format_references(references):
    """Render ``(row, score)`` pairs as the numbered reference list inserted into the prompt."""
    # The score is a cosine similarity in [-1, 1]; scale it so the "%" label is true.
    return "".join(
        "{}\n{},此篇文章跟问题的相关度为{}%\n".format(rank, row[1], round(float(score) * 100, 2))
        for rank, (row, score) in enumerate(references, start=1)
    )


def main(question, title, top):
    """Retrieve references for ``question`` and return a generator of answer chunks.

    :param question: the user's question / symptom description.
    :param title: comma-separated knowledge-base names; each name maps to
        ``data_file_res/<name>.csv``.
    :param top: number of references to retrieve per knowledge base (anything
        ``int()`` accepts).
    :return: the generator produced by ``model_generate_stream`` for the built prompt.
    :raises ValueError: if a knowledge-base name is empty or contains a path component.
    """
    # The question embedding does not depend on the knowledge base: compute it once.
    embs = shengcehng_array([question])
    sections = []
    paper_list_str = ""
    for title_dan in title.split(","):
        title_dan = title_dan.strip()
        # ``title`` arrives over the network; refuse names that could escape data_file_res/.
        if not title_dan or os.path.basename(title_dan) != title_dan:
            raise ValueError(f"invalid knowledge-base title: {title_dan!r}")
        references = _search_references(embs, _load_valid_rows(title_dan), top)
        references_str = _format_references(references)
        paper_list_str += references_str
        # Each source's section lists only that source's own references.
        sections.append(propmt_connect_ziliao.format(title_dan, references_str))
    print("paper_list_str", paper_list_str)
    propmt_connect_input = propmt_connect.format(question, "".join(sections))
    return model_generate_stream(propmt_connect_input)
def model_generate_stream(prompt):
    """Stream a chat completion for ``prompt`` and yield the answer text chunks.

    Reasoning text (``delta.reasoning_content``, sent by some OpenAI-compatible
    servers) is only echoed to stdout; answer text (``delta.content``) is
    echoed to stdout and yielded to the caller.
    """
    messages = [
        {"role": "user", "content": prompt}
    ]
    stream = client.chat.completions.create(model=model,
                                            messages=messages,
                                            stream=True)
    printed_reasoning_content = False
    printed_content = False
    for chunk in stream:
        # Some servers emit chunks with an empty ``choices`` list (e.g. a final usage chunk).
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        # Read both fields independently. A delta may carry reasoning_content=None
        # next to real content (or both at the reasoning->answer transition);
        # a ``hasattr(...)`` / ``elif`` check would drop that content.
        reasoning_content = getattr(delta, "reasoning_content", None)
        content = getattr(delta, "content", None)
        if reasoning_content is not None:
            if not printed_reasoning_content:
                printed_reasoning_content = True
                print("reasoning_content:", end="", flush=True)
            print(reasoning_content, end="", flush=True)
        if content is not None:
            if not printed_content:
                printed_content = True
                print("\ncontent:", end="", flush=True)
            print(content)
            yield content
async def handle_websocket(websocket):
    """Serve one WebSocket client until it disconnects.

    Each incoming message must be JSON with ``texts`` (the question), ``title``
    (comma-separated knowledge bases) and ``top`` (references per base). The
    answer is sent back one chunk per message, followed by "[DONE]". Malformed
    JSON and processing errors are reported to the client as small JSON error
    strings, and the connection stays open for further messages.
    """
    print("客户端已连接")
    loop = asyncio.get_running_loop()
    end_of_stream = object()  # sentinel returned by next() once the answer is exhausted
    try:
        async for message in websocket:
            try:
                data = json.loads(message)
                texts = data.get("texts")
                title = data.get("title")
                top = data.get("top")
                print(f"收到消息: {texts}")
                # main() (embedding, CSV reads, faiss) and every step of its
                # generator (the LLM HTTP stream) are blocking calls. Run them
                # in the default thread pool so the event loop keeps serving
                # other clients meanwhile.
                response = await loop.run_in_executor(None, main, texts, title, top)
                chunks = iter(response)
                while True:
                    char = await loop.run_in_executor(None, next, chunks, end_of_stream)
                    if char is end_of_stream:
                        break
                    await websocket.send(char)
                    await asyncio.sleep(0.001)  # small delay to avoid sending too fast
                # Completion marker for the client.
                await websocket.send("[DONE]")
                print("消息发送完成")
            except json.JSONDecodeError:
                await websocket.send('{"error": "Invalid JSON format"}')
            except websockets.exceptions.ConnectionClosed:
                # The client left mid-answer: let the outer handler log it rather
                # than trying to send an error frame on a closed socket.
                raise
            except Exception as e:
                print(f"处理消息时发生错误: {e}")
                await websocket.send('{"error": "Internal server error"}')
    except websockets.exceptions.ConnectionClosed:
        print("客户端断开连接")
    except Exception as e:
        print(f"WebSocket处理异常: {e}")
    finally:
        print("连接处理结束")
async def main_api(host="0.0.0.0", port=27001, ssl_cert="server.crt", ssl_key="server.key"):
    """Start the WebSocket server and run until it is closed.

    Serves ``handle_websocket`` on ``host``:``port``. If both ``ssl_cert`` and
    ``ssl_key`` exist and load successfully the server speaks WSS, otherwise
    plain WS. The defaults reproduce the original hard-coded configuration.
    Startup errors are printed with a traceback rather than raised.
    """
    try:
        ssl_context = None
        # Use TLS only when both certificate files are present.
        if os.path.exists(ssl_cert) and os.path.exists(ssl_key):
            try:
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                ssl_context.load_cert_chain(ssl_cert, ssl_key)
                # A server-side context does not check hostnames or request
                # client certificates by default; these lines only make that
                # explicit (check_hostname must be cleared before CERT_NONE).
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                print("SSL证书已加载,使用WSS协议")
            except Exception as e:
                print(f"SSL证书加载失败: {e}")
                print("将使用WS协议")
                ssl_context = None
        else:
            print("警告:SSL证书文件未找到,将使用WS协议")
            ssl_context = None
        server = await websockets.serve(
            handle_websocket,
            host,
            port,
            ssl=ssl_context,
            ping_interval=30,
            ping_timeout=30,
            close_timeout=30
        )
        if ssl_context:
            print(f"WSS服务器已启动: wss://{host}:{port}")
        else:
            print(f"WS服务器已启动: ws://{host}:{port}")
        # Keep serving until the server is closed.
        await server.wait_closed()
    except Exception as e:
        print(f"服务器启动失败: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
    import logging
    import traceback

    # Root logging at INFO level. (This does not enable asyncio debug mode.)
    logging.basicConfig(level=logging.INFO)
    try:
        asyncio.run(main_api())
    except KeyboardInterrupt:
        print("服务器被用户中断")
    except Exception as err:
        print(f"服务器运行错误: {err}")
        traceback.print_exc()