|
@ -10,15 +10,15 @@ from sentence_transformers import SentenceTransformer |
|
|
import requests |
|
|
import requests |
|
|
import time |
|
|
import time |
|
|
from flask import Flask, jsonify, Response, request |
|
|
from flask import Flask, jsonify, Response, request |
|
|
from openai import OpenAI |
|
|
|
|
|
from flask_cors import CORS |
|
|
from flask_cors import CORS |
|
|
import pandas as pd |
|
|
import pandas as pd |
|
|
import concurrent.futures |
|
|
|
|
|
import json |
|
|
|
|
|
from threading import Thread |
|
|
|
|
|
import redis |
|
|
import redis |
|
|
|
|
|
from openai import OpenAI |
|
|
import asyncio |
|
|
import asyncio |
|
|
import websockets |
|
|
import websockets |
|
|
|
|
|
import json |
|
|
|
|
|
import ssl |
|
|
|
|
|
import pathlib |
|
|
|
|
|
|
|
|
app = Flask(__name__) |
|
|
app = Flask(__name__) |
|
|
CORS(app) |
|
|
CORS(app) |
|
@ -144,24 +144,30 @@ def main(question, title, top): |
|
|
index = faiss.IndexFlatIP(d) # buid the index |
|
|
index = faiss.IndexFlatIP(d) # buid the index |
|
|
|
|
|
|
|
|
# 查找向量 |
|
|
# 查找向量 |
|
|
vector_path = f"data_np/{title_dan}.npy" |
|
|
# vector_path = f"data_np/{title_dan}.npy" |
|
|
vectors = np.load(vector_path) |
|
|
# vectors = np.load(vector_path) |
|
|
|
|
|
|
|
|
data_str = pd.read_csv(f"data_file/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() |
|
|
data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() |
|
|
|
|
|
|
|
|
|
|
|
data_str_valid = [] |
|
|
|
|
|
for i in data_str: |
|
|
|
|
|
if i[3] == True: |
|
|
|
|
|
data_str_valid.append(i) |
|
|
|
|
|
|
|
|
|
|
|
data_str_vectors_list = [] |
|
|
|
|
|
for i in data_str_valid: |
|
|
|
|
|
data_str_vectors_list.append(eval(i[-1])) |
|
|
|
|
|
vectors = np.array(data_str_vectors_list) |
|
|
index.add(vectors) |
|
|
index.add(vectors) |
|
|
D, I = index.search(embs, int(top)) |
|
|
D, I = index.search(embs, int(top)) |
|
|
print(I) |
|
|
print(I) |
|
|
|
|
|
|
|
|
reference_list = [] |
|
|
reference_list = [] |
|
|
for i,j in zip(I[0], D[0]): |
|
|
for i,j in zip(I[0], D[0]): |
|
|
print("i", i) |
|
|
reference_list.append([data_str_valid[i], j]) |
|
|
print("data_str[i]", data_str[i]) |
|
|
|
|
|
reference_list.append([data_str[i], j]) |
|
|
|
|
|
|
|
|
|
|
|
for i,j in enumerate(reference_list): |
|
|
for i,j in enumerate(reference_list): |
|
|
paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1]) |
|
|
paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
''' |
|
|
''' |
|
|
构造prompt |
|
|
构造prompt |
|
|
''' |
|
|
''' |
|
@ -216,28 +222,102 @@ def model_generate_stream(prompt): |
|
|
async def handle_websocket(websocket): |
|
|
async def handle_websocket(websocket): |
|
|
print("客户端已连接") |
|
|
print("客户端已连接") |
|
|
try: |
|
|
try: |
|
|
while True: |
|
|
async for message in websocket: |
|
|
message = await websocket.recv() |
|
|
try: |
|
|
data = json.loads(message) |
|
|
data = json.loads(message) |
|
|
texts = data.get("texts") |
|
|
texts = data.get("texts") |
|
|
title = data.get("title") |
|
|
title = data.get("title") |
|
|
top = data.get("top") |
|
|
top = data.get("top") |
|
|
print("收到消息:", message) |
|
|
print(f"收到消息: {texts}") |
|
|
|
|
|
|
|
|
|
|
|
# 获取响应 |
|
|
response = main(texts, title, top) |
|
|
response = main(texts, title, top) |
|
|
# response = message + "111" |
|
|
|
|
|
|
|
|
# 发送响应 |
|
|
for char in response: |
|
|
for char in response: |
|
|
await websocket.send(char) |
|
|
await websocket.send(char) |
|
|
# await asyncio.sleep(0.3) |
|
|
await asyncio.sleep(0.001) # 小延迟避免发送过快 |
|
|
|
|
|
|
|
|
|
|
|
# 发送完成标记 |
|
|
await websocket.send("[DONE]") |
|
|
await websocket.send("[DONE]") |
|
|
|
|
|
print("消息发送完成") |
|
|
|
|
|
|
|
|
|
|
|
except json.JSONDecodeError: |
|
|
|
|
|
await websocket.send('{"error": "Invalid JSON format"}') |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"处理消息时发生错误: {e}") |
|
|
|
|
|
await websocket.send('{"error": "Internal server error"}') |
|
|
|
|
|
|
|
|
except websockets.exceptions.ConnectionClosed: |
|
|
except websockets.exceptions.ConnectionClosed: |
|
|
print("客户端断开连接") |
|
|
print("客户端断开连接") |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"WebSocket处理异常: {e}") |
|
|
|
|
|
finally: |
|
|
|
|
|
print("连接处理结束") |
|
|
|
|
|
|
|
|
async def main_api(): |
|
|
async def main_api(): |
|
|
async with websockets.serve(handle_websocket, "0.0.0.0", 27001): |
|
|
try: |
|
|
print("WebSocket 服务器已启动,监听端口 27001") |
|
|
ssl_context = None |
|
|
await asyncio.Future() # 永久运行 |
|
|
|
|
|
|
|
|
# 检查证书文件是否存在 |
|
|
|
|
|
ssl_cert = "server.crt" |
|
|
|
|
|
ssl_key = "server.key" |
|
|
|
|
|
|
|
|
|
|
|
if os.path.exists(ssl_cert) and os.path.exists(ssl_key): |
|
|
|
|
|
try: |
|
|
|
|
|
# 创建SSL上下文 |
|
|
|
|
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
|
|
|
|
|
# 加载证书链 |
|
|
|
|
|
ssl_context.load_cert_chain(ssl_cert, ssl_key) |
|
|
|
|
|
# 禁用主机名验证(对于自签名证书) |
|
|
|
|
|
ssl_context.check_hostname = False |
|
|
|
|
|
ssl_context.verify_mode = ssl.CERT_NONE |
|
|
|
|
|
print("SSL证书已加载,使用WSS协议") |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"SSL证书加载失败: {e}") |
|
|
|
|
|
print("将使用WS协议") |
|
|
|
|
|
ssl_context = None |
|
|
|
|
|
else: |
|
|
|
|
|
print("警告:SSL证书文件未找到,将使用WS协议") |
|
|
|
|
|
ssl_context = None |
|
|
|
|
|
|
|
|
|
|
|
# 创建服务器 |
|
|
|
|
|
server = await websockets.serve( |
|
|
|
|
|
handle_websocket, |
|
|
|
|
|
"0.0.0.0", |
|
|
|
|
|
27001, |
|
|
|
|
|
ssl=ssl_context, |
|
|
|
|
|
ping_interval=30, # 添加ping间隔 |
|
|
|
|
|
ping_timeout=30, # 添加ping超时 |
|
|
|
|
|
close_timeout=30 # 添加关闭超时 |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if ssl_context: |
|
|
|
|
|
print("WSS服务器已启动: wss://0.0.0.0:27001") |
|
|
|
|
|
else: |
|
|
|
|
|
print("WS服务器已启动: ws://0.0.0.0:27001") |
|
|
|
|
|
|
|
|
|
|
|
# 保持服务器运行 |
|
|
|
|
|
await server.wait_closed() |
|
|
|
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"服务器启动失败: {e}") |
|
|
|
|
|
import traceback |
|
|
|
|
|
traceback.print_exc() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
if __name__ == "__main__": |
|
|
asyncio.run(main_api()) # 正确启动事件循环 |
|
|
# 设置更详细的事件循环调试 |
|
|
|
|
|
import logging |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
asyncio.run(main_api()) |
|
|
|
|
|
except KeyboardInterrupt: |
|
|
|
|
|
print("服务器被用户中断") |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
print(f"服务器运行错误: {e}") |
|
|
|
|
|
import traceback |
|
|
|
|
|
|
|
|
|
|
|
traceback.print_exc() |
|
|