Browse Source

增加测试scokt请求,并更改流程

div_测试
majiahui@haimaqingfan.com 2 weeks ago
parent
commit
2f48a11ac8
  1. 148
      main_scokt.py

148
main_scokt.py

@ -10,15 +10,15 @@ from sentence_transformers import SentenceTransformer
import requests import requests
import time import time
from flask import Flask, jsonify, Response, request from flask import Flask, jsonify, Response, request
from openai import OpenAI
from flask_cors import CORS from flask_cors import CORS
import pandas as pd import pandas as pd
import concurrent.futures
import json
from threading import Thread
import redis import redis
from openai import OpenAI
import asyncio import asyncio
import websockets import websockets
import json
import ssl
import pathlib
app = Flask(__name__) app = Flask(__name__)
CORS(app) CORS(app)
@ -144,24 +144,30 @@ def main(question, title, top):
index = faiss.IndexFlatIP(d) # buid the index index = faiss.IndexFlatIP(d) # buid the index
# 查找向量 # 查找向量
vector_path = f"data_np/{title_dan}.npy" # vector_path = f"data_np/{title_dan}.npy"
vectors = np.load(vector_path) # vectors = np.load(vector_path)
data_str = pd.read_csv(f"data_file/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist()
data_str_valid = []
for i in data_str:
if i[3] == True:
data_str_valid.append(i)
data_str_vectors_list = []
for i in data_str_valid:
data_str_vectors_list.append(eval(i[-1]))
vectors = np.array(data_str_vectors_list)
index.add(vectors) index.add(vectors)
D, I = index.search(embs, int(top)) D, I = index.search(embs, int(top))
print(I) print(I)
reference_list = [] reference_list = []
for i, j in zip(I[0], D[0]): for i,j in zip(I[0], D[0]):
print("i", i) reference_list.append([data_str_valid[i], j])
print("data_str[i]", data_str[i])
reference_list.append([data_str[i], j])
for i, j in enumerate(reference_list):
paper_list_str += "{}\n{},此篇文章跟问题的相关度为{}%\n".format(str(i + 1), j[0][1], j[1])
for i,j in enumerate(reference_list):
paper_list_str += "{}\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1])
''' '''
构造prompt 构造prompt
''' '''
@ -216,28 +222,102 @@ def model_generate_stream(prompt):
async def handle_websocket(websocket): async def handle_websocket(websocket):
print("客户端已连接") print("客户端已连接")
try: try:
while True: async for message in websocket:
message = await websocket.recv() try:
data = json.loads(message) data = json.loads(message)
texts = data.get("texts") texts = data.get("texts")
title = data.get("title") title = data.get("title")
top = data.get("top") top = data.get("top")
print("收到消息:", message) print(f"收到消息: {texts}")
response = main(texts, title, top) # 获取响应
# response = message + "111" response = main(texts, title, top)
for char in response:
await websocket.send(char) # 发送响应
# await asyncio.sleep(0.3) for char in response:
await websocket.send("[DONE]") await websocket.send(char)
await asyncio.sleep(0.001) # 小延迟避免发送过快
# 发送完成标记
await websocket.send("[DONE]")
print("消息发送完成")
except json.JSONDecodeError:
await websocket.send('{"error": "Invalid JSON format"}')
except Exception as e:
print(f"处理消息时发生错误: {e}")
await websocket.send('{"error": "Internal server error"}')
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
print("客户端断开连接") print("客户端断开连接")
except Exception as e:
print(f"WebSocket处理异常: {e}")
finally:
print("连接处理结束")
async def main_api(): async def main_api():
async with websockets.serve(handle_websocket, "0.0.0.0", 27001): try:
print("WebSocket 服务器已启动,监听端口 27001") ssl_context = None
await asyncio.Future() # 永久运行
# 检查证书文件是否存在
ssl_cert = "server.crt"
ssl_key = "server.key"
if os.path.exists(ssl_cert) and os.path.exists(ssl_key):
try:
# 创建SSL上下文
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# 加载证书链
ssl_context.load_cert_chain(ssl_cert, ssl_key)
# 禁用主机名验证(对于自签名证书)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
print("SSL证书已加载,使用WSS协议")
except Exception as e:
print(f"SSL证书加载失败: {e}")
print("将使用WS协议")
ssl_context = None
else:
print("警告:SSL证书文件未找到,将使用WS协议")
ssl_context = None
# 创建服务器
server = await websockets.serve(
handle_websocket,
"0.0.0.0",
27001,
ssl=ssl_context,
ping_interval=30, # 添加ping间隔
ping_timeout=30, # 添加ping超时
close_timeout=30 # 添加关闭超时
)
if ssl_context:
print("WSS服务器已启动: wss://0.0.0.0:27001")
else:
print("WS服务器已启动: ws://0.0.0.0:27001")
# 保持服务器运行
await server.wait_closed()
except Exception as e:
print(f"服务器启动失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(main_api()) # 正确启动事件循环 # 设置更详细的事件循环调试
import logging
logging.basicConfig(level=logging.INFO)
try:
asyncio.run(main_api())
except KeyboardInterrupt:
print("服务器被用户中断")
except Exception as e:
print(f"服务器运行错误: {e}")
import traceback
traceback.print_exc()

Loading…
Cancel
Save