From 2f48a11ac8dddc99218466c9cdfb630793121ad9 Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Mon, 1 Sep 2025 10:17:49 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95scokt?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=EF=BC=8C=E5=B9=B6=E6=9B=B4=E6=94=B9=E6=B5=81?= =?UTF-8?q?=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_scokt.py | 148 ++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 114 insertions(+), 34 deletions(-) diff --git a/main_scokt.py b/main_scokt.py index 0dbad79..3437755 100644 --- a/main_scokt.py +++ b/main_scokt.py @@ -10,15 +10,15 @@ from sentence_transformers import SentenceTransformer import requests import time from flask import Flask, jsonify, Response, request -from openai import OpenAI from flask_cors import CORS import pandas as pd -import concurrent.futures -import json -from threading import Thread import redis +from openai import OpenAI import asyncio import websockets +import json +import ssl +import pathlib app = Flask(__name__) CORS(app) @@ -144,24 +144,30 @@ def main(question, title, top): index = faiss.IndexFlatIP(d) # buid the index # 查找向量 - vector_path = f"data_np/{title_dan}.npy" - vectors = np.load(vector_path) + # vector_path = f"data_np/{title_dan}.npy" + # vectors = np.load(vector_path) - data_str = pd.read_csv(f"data_file/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() + data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() + + data_str_valid = [] + for i in data_str: + if i[3] == True: + data_str_valid.append(i) + + data_str_vectors_list = [] + for i in data_str_valid: + data_str_vectors_list.append(eval(i[-1])) + vectors = np.array(data_str_vectors_list) index.add(vectors) D, I = index.search(embs, int(top)) print(I) reference_list = [] - for i, j in zip(I[0], D[0]): - print("i", i) - print("data_str[i]", data_str[i]) - reference_list.append([data_str[i], j]) - - for i, j in enumerate(reference_list): - paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i + 1), j[0][1], j[1]) - + for i,j in zip(I[0], D[0]): + reference_list.append([data_str_valid[i], j]) + for i,j in enumerate(reference_list): + paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1]) ''' 构造prompt ''' @@ -216,28 +222,102 @@ def model_generate_stream(prompt): async def handle_websocket(websocket): print("客户端已连接") try: - while True: - message = await websocket.recv() - data = json.loads(message) - texts = data.get("texts") - title = data.get("title") - top = data.get("top") - print("收到消息:", message) - - response = main(texts, title, top) - # response = message + "111" - for char in response: - await websocket.send(char) - # await asyncio.sleep(0.3) - await websocket.send("[DONE]") + async for message in websocket: + try: + data = json.loads(message) + texts = data.get("texts") + title = data.get("title") + top = data.get("top") + print(f"收到消息: {texts}") + + # 获取响应 + response = main(texts, title, top) + + # 发送响应 + for char in response: + await websocket.send(char) + await asyncio.sleep(0.001) # 小延迟避免发送过快 + + # 发送完成标记 + await websocket.send("[DONE]") + print("消息发送完成") + + except json.JSONDecodeError: + await websocket.send('{"error": "Invalid JSON format"}') + except Exception as e: + print(f"处理消息时发生错误: {e}") + await websocket.send('{"error": "Internal server error"}') + except websockets.exceptions.ConnectionClosed: print("客户端断开连接") - + except Exception as e: + print(f"WebSocket处理异常: {e}") + finally: + print("连接处理结束") async def main_api(): - async with websockets.serve(handle_websocket, "0.0.0.0", 27001): - print("WebSocket 服务器已启动,监听端口 27001") - await asyncio.Future() # 永久运行 + try: + ssl_context = None + + # 检查证书文件是否存在 + ssl_cert = "server.crt" + ssl_key = "server.key" + + if os.path.exists(ssl_cert) and os.path.exists(ssl_key): + try: + # 创建SSL上下文 + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + # 加载证书链 + ssl_context.load_cert_chain(ssl_cert, ssl_key) + # 禁用主机名验证(对于自签名证书) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + print("SSL证书已加载,使用WSS协议") + except Exception as e: + print(f"SSL证书加载失败: {e}") + print("将使用WS协议") + ssl_context = None + else: + print("警告:SSL证书文件未找到,将使用WS协议") + ssl_context = None + + # 创建服务器 + server = await websockets.serve( + handle_websocket, + "0.0.0.0", + 27001, + ssl=ssl_context, + ping_interval=30, # 添加ping间隔 + ping_timeout=30, # 添加ping超时 + close_timeout=30 # 添加关闭超时 + ) + + if ssl_context: + print("WSS服务器已启动: wss://0.0.0.0:27001") + else: + print("WS服务器已启动: ws://0.0.0.0:27001") + + # 保持服务器运行 + await server.wait_closed() + + except Exception as e: + print(f"服务器启动失败: {e}") + import traceback + traceback.print_exc() + if __name__ == "__main__": - asyncio.run(main_api()) # 正确启动事件循环 + # 设置更详细的事件循环调试 + import logging + + logging.basicConfig(level=logging.INFO) + + try: + asyncio.run(main_api()) + except KeyboardInterrupt: + print("服务器被用户中断") + except Exception as e: + print(f"服务器运行错误: {e}") + import traceback + + traceback.print_exc()