def model_generate_stream(prompt, *, llm_client=None, model_name=None):
    """Stream a chat completion for ``prompt`` and yield the answer text.

    Reasoning tokens (``delta.reasoning_content``) are only echoed to stdout.
    Answer tokens (``delta.content``) are echoed to stdout and yielded, one
    chunk at a time, in the order the server sends them.

    :param prompt: user message sent as the only turn of the conversation.
    :param llm_client: optional OpenAI-compatible client; when omitted the
        module-level ``client`` is used.
    :param model_name: optional model id; when omitted the module-level
        ``model`` is used.
    :return: generator of answer chunks (every yielded value is not ``None``).
    """
    active_client = client if llm_client is None else llm_client
    active_model = model if model_name is None else model_name

    stream = active_client.chat.completions.create(
        model=active_model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    printed_reasoning_header = False
    printed_content_header = False

    for chunk in stream:
        # A streamed chunk may carry an empty ``choices`` list (for example a
        # trailing usage record); there is no delta to read in that case.
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        delta = choices[0].delta

        # Read both fields independently. A delta can expose a
        # ``reasoning_content`` attribute whose value is None while ``content``
        # holds answer text; an ``hasattr``/``elif`` test would drop that text.
        reasoning_content = getattr(delta, "reasoning_content", None)
        content = getattr(delta, "content", None)

        if reasoning_content is not None:
            if not printed_reasoning_header:
                printed_reasoning_header = True
                print("reasoning_content:", end="", flush=True)
            print(reasoning_content, end="", flush=True)
        elif content is not None:
            if not printed_content_header:
                printed_content_header = True
                print("\ncontent:", end="", flush=True)
            print(content)
            yield content
async def handle_websocket(websocket):
    """Serve one WebSocket client, answering each JSON request with a stream.

    Each incoming message is expected to be a JSON object whose ``texts``
    field is the prompt. The answer chunks produced by
    ``model_generate_stream`` are forwarded one frame per chunk, followed by
    a literal ``[DONE]`` frame. A message that is not a JSON object is
    answered with an error frame and the connection stays open.

    :param websocket: connection object handed over by ``websockets.serve``.
    """
    print("客户端已连接")
    loop = asyncio.get_running_loop()
    exhausted = object()  # sentinel: lets next() signal the end without StopIteration
    try:
        while True:
            message = await websocket.recv()
            print("收到消息:", message)

            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                data = None
            if not isinstance(data, dict):
                # Reject the message but keep serving this client.
                await websocket.send('{"error": "Invalid JSON format"}')
                continue

            chunks = model_generate_stream(data.get("texts"))
            # The generator blocks on network reads, so each step runs in the
            # default executor; the event loop stays free for other clients.
            while True:
                piece = await loop.run_in_executor(None, next, chunks, exhausted)
                if piece is exhausted:
                    break
                await websocket.send(piece)
            await websocket.send("[DONE]")
    except websockets.exceptions.ConnectionClosed:
        print("客户端断开连接")


async def main():
    """Start the WebSocket server on port 5500 and run until cancelled."""
    async with websockets.serve(handle_websocket, "0.0.0.0", 5500):
        print("WebSocket 服务器已启动,监听端口 5500")
        await asyncio.Future()  # never completes: keeps the server alive


if __name__ == "__main__":
    asyncio.run(main())
根据这些症状,我通过查找资料,{} @@ -26,7 +48,7 @@ propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容: {}''' -def dialog_line_parse(url, text): +def dialog_line_parse(text): """ 将数据输入模型进行分析并输出结果 :param url: 模型url @@ -34,8 +56,9 @@ def dialog_line_parse(url, text): :return: 模型返回结果 """ + url_predict = "http://118.178.228.101:12004/predict" response = requests.post( - url, + url_predict, json=text, timeout=100000 ) @@ -49,46 +72,261 @@ def dialog_line_parse(url, text): # ) print("【{}】 Failed to get a proper response from remote " "server. Status Code: {}. Response: {}" - "".format(url, response.status_code, response.text)) + "".format(url_predict, response.status_code, response.text)) return {} + # ['choices'][0]['message']['content'] + # + # text = text['messages'][0]['content'] + # return_text = { + # 'code': 200, + # 'id': "1", + # 'object': 0, + # 'created': 0, + # 'model': 0, + # 'choices': [ + # { + # 'index': 0, + # 'message': { + # 'role': 'assistant', + # 'content': text + # }, + # 'logprobs': None, + # 'finish_reason': 'stop' + # } + # ], + # 'usage': 0, + # 'system_fingerprint': 0 + # } + # return return_text + def shengcehng_array(data): - embs = model.encode(data, normalize_embeddings=True) + ''' + 模型生成向量 + :param data: + :return: + ''' + embs = model_encode.encode(data, normalize_embeddings=True) return embs -def Building_vector_database(type, name, df): - data_ndarray = np.empty((0, 1024)) - for sen in df: - data_ndarray = np.concatenate((data_ndarray, shengcehng_array([sen[0]]))) - print("data_ndarray.shape", data_ndarray.shape) - print("data_ndarray.shape", data_ndarray.shape) - np.save(f'data_np/{name}.npy', data_ndarray) +def Building_vector_database(title, df): + ''' + 次函数暂时弃用 + :param title: + :param df: + :return: + ''' + # 加载需要处理的数据(有效且未向量化) + to_process = df[(df["有效"] == True) & (df["已向量化"] == False)] + + if len(to_process) == 0: + print("无新增数据需要向量化") + return + + # 生成向量数组 + new_vectors = shengcehng_array(to_process["总结"].tolist()) # 假设这是你的向量生成函数 + # 加载现有向量库和索引 + 
vector_path = f"data_np/{title}.npy" + index_path = f"data_np/{title}_index.json" -def ulit_request_file(file, title): - file_name = file.filename - file_name_save = "data_file/{}.csv".format(title) - file.save(file_name_save) + vectors = np.load(vector_path) if os.path.exists(vector_path) else np.empty((0, 1024)) + index_data = {} + if os.path.exists(index_path): + with open(index_path, "r") as f: + index_data = json.load(f) - # try: - # with open(file_name_save, encoding="gbk") as f: - # content = f.read() - # except: - # with open(file_name_save, encoding="utf-8") as f: - # content = f.read() - # elif file_name.split(".")[-1] == "docx": - # content = docx2txt.process(file_name_save) + # 更新索引和向量库 + start_idx = len(vectors) + vectors = np.vstack([vectors, new_vectors]) - # content_list = [i for i in content.split("\n")] - df = pd.read_csv(file_name_save, sep="\t", encoding="utf-8").values.tolist() + for i, (_, row) in enumerate(to_process.iterrows()): + index_data[row["ID"]] = { + "row": start_idx + i, + "valid": True + } - return df + # 保存数据 + np.save(vector_path, vectors) + with open(index_path, "w") as f: + json.dump(index_data, f) + # 标记已向量化 + df.loc[to_process.index, "已向量化"] = True + df.to_csv(f"data_file_res/{title}.csv", sep="\t", index=False) -def main(question, db_type, top): + +def delete_data(title, new_id): + ''' + 假删除,只是标记有效无效 + :param title: + :param new_id: + :return: + ''' + new_id = str(new_id) + # 更新CSV标记 + csv_path = f"data_file_res/{title}.csv" + df = pd.read_csv(csv_path, sep="\t") + # df.loc[df["ID"] == new_id, "有效"] = False + df.loc[df['ID'] == new_id, "有效"] = False + df.to_csv(csv_path, sep="\t", index=False) + return "删除完成" + + +def check_file_exists(file_path): + """ + 检查文件是否存在 + + 参数: + file_path (str): 要检查的文件路径 + + 返回: + bool: 文件存在返回True,否则返回False + """ + return os.path.isfile(file_path) + + +def ulit_request_file(sentence, title, zongjie): + ''' + 上传文件,生成固定内容,"ID", "正文", "总结", "有效", "向量" + :param sentence: + :param title: + :param 
zongjie: + :return: + ''' + file_name_res_save = f"data_file_res/{title}.csv" + + # 初始化或读取CSV文件,如果存在文件,读取文件,并添加行, + # 如果不存在文件,新建DataFrame + if os.path.exists(file_name_res_save): + df = pd.read_csv(file_name_res_save, sep="\t") + # 检查是否已存在相同正文 + if sentence in df["正文"].values: + if zongjie == None: + return "正文已存在,跳过处理" + else: + result = df[df['正文'] == sentence] + id_ = result['ID'].values[0] + print(id_) + return ulit_request_file_zongjie(id_, sentence, zongjie, title) + else: + df = pd.DataFrame(columns=["ID", "正文", "总结", "有效", "向量"]) + + # 添加新数据(生成唯一ID) + if zongjie == None: + id_ = str(uuid.uuid1()) + new_row = { + "ID": id_, + "正文": sentence, + "总结": None, + "有效": True, + "向量": None + } + df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) + + # 需要根据不同的项目修改提示,目的是精简内容,为了方便匹配 + data_dan = { + "model": "gpt-4-turbo", + "messages": [{ + "role": "user", + "content": f"{sentence}\n以上这条中可能包含了一些病情或者症状,请帮我归纳这条中所对应的病情或者症状是哪些,总结出来,不需要很长,简单归纳即可,直接输出症状或者病情,可以包含一些形容词来辅助描述,不需要有辅助词汇" + }], + "top_p": 0.9, + "temperature": 0.3 + } + results = dialog_line_parse(data_dan) + summary = results['choices'][0]['message']['content'] + + # 这是你的向量生成函数,来生成总结的词汇的向量 + new_vectors = shengcehng_array([summary]) + df.loc[df['ID'] == id_, '总结'] = summary + df.loc[df['ID'] == id_, '向量'] = str(new_vectors[0].tolist()) + + else: + id_ = str(uuid.uuid1()) + new_row = { + "ID": id_ , + "正文": sentence, + "总结": zongjie, + "有效": True, + "向量": None + } + df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) + new_vectors = shengcehng_array([zongjie]) # 假设这是你的向量生成函数 + df.loc[df['ID'] == id_, '总结'] = zongjie + df.loc[df['ID'] == id_, '向量'] = str(new_vectors[0].tolist()) + + # 保存更新后的CSV + df.to_csv(file_name_res_save, sep="\t", index=False) + return "上传完成" + + +def ulit_request_file_zongjie(new_id, sentence, zongjie, title): + new_id = str(new_id) + print(new_id) + print(type(new_id)) + file_name_res_save = f"data_file_res/{title}.csv" + + # 初始化或读取CSV文件 + + df = 
pd.read_csv(file_name_res_save, sep="\t") + df.loc[df['ID'] == new_id, '正文'] = sentence + if zongjie == None: + pass + else: + df.loc[df['ID'] == new_id, '总结'] = zongjie + new_vectors = shengcehng_array([zongjie]) # 假设这是你的向量生成函数 + df.loc[df['ID'] == new_id, '向量'] = str(new_vectors[0].tolist()) + + # 保存更新后的CSV + df.to_csv(file_name_res_save, sep="\t", index=False) + return "修改完成" + + +def ulit_request_file_check(title): + file_name_res_save = f"data_file_res/{title}.csv" + + # 初始化或读取CSV文件 + + # 初始化或读取CSV文件 + if os.path.exists(file_name_res_save): + df = pd.read_csv(file_name_res_save, sep="\t").values.tolist() + data_new = [] + for i in df: + if i[3] == True: + data_new.append([i[0], i[1], i[2]]) + return data_new + else: + return "无可展示文件" + + +def ulit_request_file_check_dan(new_id, title): + new_id = str(new_id) + file_name_res_save = f"data_file_res/{title}.csv" + + # 初始化或读取CSV文件 + + # 初始化或读取CSV文件 + if os.path.exists(file_name_res_save): + df = pd.read_csv(file_name_res_save, sep="\t") + zhengwen = df.loc[df['ID'] == new_id, '正文'].values + zongjie = df.loc[df['ID'] == new_id, '总结'].values + # 输出结果 + if len(zhengwen) > 0: + if df.loc[df['ID'] == new_id, '有效'].values == True: + return [new_id, zhengwen[0], zongjie[0]] + else: + return "未找到对应的ID" + else: + return "未找到对应的ID" + else: + return "无可展示文件" + + +def main(question, title, top): db_dict = { "1": "yetianshi" } @@ -114,26 +352,38 @@ def main(question, db_type, top): 根据提问匹配上下文 ''' d = 1024 - db_type_list = db_type.split(",") + db_type_list = title.split(",") paper_list_str = "" - for i in db_type_list: + for title_dan in db_type_list: embs = shengcehng_array([question]) index = faiss.IndexFlatIP(d) # buid the index - data_np = np.load(f"data_np/{i}.npy") - # data_str = open(f"data_file/{i}.txt").read().split("\n") - data_str = pd.read_csv(f"data_file/{i}.csv", sep="\t", encoding="utf-8").values.tolist() - index.add(data_np) + + # 查找向量 + # vector_path = f"data_np/{title_dan}.npy" + # vectors = 
np.load(vector_path) + + data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() + + data_str_valid = [] + for i in data_str: + if i[3] == True: + data_str_valid.append(i) + + data_str_vectors_list = [] + for i in data_str_valid: + data_str_vectors_list.append(eval(i[-1])) + vectors = np.array(data_str_vectors_list) + index.add(vectors) D, I = index.search(embs, int(top)) print(I) reference_list = [] for i,j in zip(I[0], D[0]): - reference_list.append([data_str[i], j]) + reference_list.append([data_str_valid[i], j]) for i,j in enumerate(reference_list): - paper_list_str += "第{}篇\n{},此篇文章的转发数为{},评论数为{},点赞数为{}\n,此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][0], j[0][1], j[0][2], j[0][3], j[1]) - + paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1]) ''' 构造prompt ''' @@ -147,61 +397,86 @@ def main(question, db_type, top): ''' 生成回答 ''' - url_predict = "http://192.168.31.74:26000/predict" - url_search = "http://192.168.31.74:26000/search" + return model_generate_stream(propmt_connect_input) + + +def model_generate_stream(prompt): + messages = [ + {"role": "user", "content": prompt} + ] + + stream = client.chat.completions.create(model=model, + messages=messages, + stream=True) + printed_reasoning_content = False + printed_content = False + + for chunk in stream: + reasoning_content = None + content = None + # Check the content is reasoning_content or content + if hasattr(chunk.choices[0].delta, "reasoning_content"): + reasoning_content = chunk.choices[0].delta.reasoning_content + elif hasattr(chunk.choices[0].delta, "content"): + content = chunk.choices[0].delta.content + + if reasoning_content is not None: + if not printed_reasoning_content: + printed_reasoning_content = True + print("reasoning_content:", end="", flush=True) + print(reasoning_content, end="", flush=True) + elif content is not None: + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + # Extract 
and print the content + # print(content, end="", flush=True) + print(content, end="") + yield content - # data = { - # "content": propmt_connect_input - # } - data = { - "content": propmt_connect_input, - "model": "qwq-32", - "top_p": 0.9, - "temperature": 0.6 - } - res = dialog_line_parse(url_predict, data) - id_ = res["texts"]["id"] +@app.route("/upload_file_check", methods=["POST"]) +def upload_file_check(): + print(request.remote_addr) + sentence = request.form.get('sentence') + title = request.form.get("title") + new_id = request.form.get("id") + zongjie = request.form.get("zongjie") + state = request.form.get("state") + ''' + { + "1": "csv", + "2": "xlsx", + "3": "txt", + "4": "pdf" + } + ''' + # 增 + state_res = "" + if state == "1": + state_res = ulit_request_file(sentence, title, zongjie) - data = { - "id": id_ - } - while True: - res = dialog_line_parse(url_search, data) - if res["code"] == 200: - break - else: - time.sleep(1) - spilt_str = "" - think, response = str(res["text"]).split(spilt_str) - return think, response + # 删 + elif state == "2": + state_res = delete_data(title, new_id) + # 改 + elif state == "3": + state_res = ulit_request_file_zongjie(new_id, sentence, zongjie,title) -@app.route("/upload_file", methods=["POST"]) -def upload_file(): - print(request.remote_addr) - file = request.files.get('file') - title = request.form.get("title") - df = ulit_request_file(file, title) - Building_vector_database("1", title, df) - return_json = { - "code": 200, - "info": "上传完成" - } - return jsonify(return_json) # 返回结果 + # 查 + elif state == "4": + state_res = ulit_request_file_check(title) + # 通过uuid查单条数据 + elif state == "5": + ulit_request_file_check_dan(new_id, title) + state_res = "" -@app.route("/upload_file_check", methods=["POST"]) -def upload_file_check(): - print(request.remote_addr) - file = request.files.get('file') - title = request.form.get("title") - df = ulit_request_file(file, title) - Building_vector_database("1", title, df) return_json = { 
"code": 200, - "info": "上传完成" + "info": state_res } return jsonify(return_json) # 返回结果 @@ -210,15 +485,10 @@ def upload_file_check(): def search(): print(request.remote_addr) texts = request.json["texts"] - text_type = request.json["text_type"] + title = request.json["title"] top = request.json["top"] - think, response = main(texts, text_type, top) - return_json = { - "code": 200, - "think": think, - "response": response - } - return jsonify(return_json) # 返回结果 + response = main(texts, title, top) + return Response(response, mimetype='text/plain; charset=utf-8') # 返回结果 if __name__ == "__main__": diff --git a/main_scokt.py b/main_scokt.py new file mode 100644 index 0000000..1ac0300 --- /dev/null +++ b/main_scokt.py @@ -0,0 +1,357 @@ +# 这是一个示例 Python 脚本。 + +# 按 Shift+F10 执行或将其替换为您的代码。 +# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +import faiss +import numpy as np +from tqdm import tqdm +from sentence_transformers import SentenceTransformer +import requests +import time +from flask import Flask, jsonify, Response, request +from flask_cors import CORS +import pandas as pd +import redis +from openai import OpenAI +import asyncio +import websockets +import json +import ssl +import pathlib + +app = Flask(__name__) +CORS(app) +app.config["JSON_AS_ASCII"] = False + +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=1, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_querying = 'querying' +batch_size = 32 + +openai_api_key = "token-abc123" +openai_api_base = "http://127.0.0.1:12011/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id +# model = "1" +model_encode = SentenceTransformer('/home/zhangbaoxun/project/models-llm/bge-large-zh-v1.5') +propmt_connect = 
def dialog_line_parse(text):
    """POST ``text`` to the remote prediction service and return its JSON reply.

    :param text: JSON-serialisable request payload.
    :return: the decoded JSON body on HTTP 200, otherwise an empty dict (the
        failure is printed, not raised).
    """
    url_predict = "http://118.178.228.101:12004/predict"
    response = requests.post(
        url_predict,
        json=text,
        timeout=100000
    )
    if response.status_code == 200:
        return response.json()

    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url_predict, response.status_code, response.text))
    return {}


def shengcehng_array(data):
    """Encode ``data`` with the module-level ``model_encode`` model.

    :param data: list of sentences to embed.
    :return: embeddings produced with ``normalize_embeddings=True``.
    """
    return model_encode.encode(data, normalize_embeddings=True)


def _load_valid_rows(title_dan):
    """Return the rows of ``data_file_res/<title_dan>.csv`` not marked deleted."""
    rows = pd.read_csv(
        f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8"
    ).values.tolist()
    # Column order in the CSV: ID, 正文, 总结, 有效, 向量.
    # ``== True`` is deliberate: the flag comes back from pandas as a NumPy
    # bool, for which an identity test against True would fail.
    return [row for row in rows if row[3] == True]  # noqa: E712


def _build_reference_text(title_dan, embs, top, dim=1024):
    """Return the numbered reference text for the best matches in one title."""
    valid_rows = _load_valid_rows(title_dan)
    if not valid_rows:
        # Nothing to index; adding an empty array to FAISS would fail.
        return ""

    # Vectors are stored as the text of a float list, which is valid JSON.
    # Parsing it as data avoids executing file content the way eval() would.
    vectors = np.array(
        [json.loads(row[-1]) for row in valid_rows], dtype="float32"
    )
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)

    D, I = index.search(embs, int(top))
    print(I)

    parts = []
    for idx, score in zip(I[0], D[0]):
        if idx < 0:
            # FAISS pads with -1 when fewer than ``top`` vectors exist; used
            # as a list index that would silently pick the last row.
            continue
        parts.append(
            "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(
                str(len(parts) + 1), valid_rows[idx][1], score
            )
        )
    return "".join(parts)


def main(question, title, top):
    """Retrieve matching passages and stream the model's answer.

    For every knowledge base named in ``title`` the ``top`` most similar
    valid rows are looked up, placed into the module-level prompt templates
    and sent to ``model_generate_stream``.

    Each title's prompt section now contains only that title's matches.
    Previously the reference text accumulated across titles, so with several
    titles every section repeated all matches.

    :param question: the user's question / symptom description.
    :param title: comma-separated knowledge-base names.
    :param top: number of passages to retrieve per knowledge base.
    :return: generator yielding the answer chunks.
    """
    # The query embedding does not depend on the title, so compute it once.
    embs = shengcehng_array([question])

    sections = []
    for title_dan in title.split(","):
        paper_list_str = _build_reference_text(title_dan, embs, top)
        print("paper_list_str", paper_list_str)
        sections.append(propmt_connect_ziliao.format(title_dan, paper_list_str))

    propmt_connect_input = propmt_connect.format(question, ",".join(sections))
    return model_generate_stream(propmt_connect_input)


def model_generate_stream(prompt, *, llm_client=None, model_name=None):
    """Stream a chat completion for ``prompt`` and yield the answer text.

    Reasoning tokens (``delta.reasoning_content``) are only echoed to stdout.
    Answer tokens (``delta.content``) are echoed to stdout and yielded, one
    chunk at a time, in the order the server sends them.

    :param prompt: user message sent as the only turn of the conversation.
    :param llm_client: optional OpenAI-compatible client; when omitted the
        module-level ``client`` is used.
    :param model_name: optional model id; when omitted the module-level
        ``model`` is used.
    :return: generator of answer chunks (every yielded value is not ``None``).
    """
    active_client = client if llm_client is None else llm_client
    active_model = model if model_name is None else model_name

    stream = active_client.chat.completions.create(
        model=active_model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    printed_reasoning_header = False
    printed_content_header = False

    for chunk in stream:
        # A streamed chunk may carry an empty ``choices`` list; skip it.
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        delta = choices[0].delta

        # Read both fields independently so that a ``reasoning_content``
        # attribute set to None does not hide the answer text in ``content``.
        reasoning_content = getattr(delta, "reasoning_content", None)
        content = getattr(delta, "content", None)

        if reasoning_content is not None:
            if not printed_reasoning_header:
                printed_reasoning_header = True
                print("reasoning_content:", end="", flush=True)
            print(reasoning_content, end="", flush=True)
        elif content is not None:
            if not printed_content_header:
                printed_content_header = True
                print("\ncontent:", end="", flush=True)
            print(content)
            yield content


async def handle_websocket(websocket):
    """Serve one WebSocket client, answering each JSON request with a stream.

    Each message must be a JSON object with ``texts``, ``title`` and ``top``.
    The answer chunks from ``main`` are forwarded one frame per chunk,
    followed by a literal ``[DONE]`` frame. Invalid JSON and processing
    errors are reported to the client as small JSON error frames.

    :param websocket: connection object handed over by ``websockets.serve``.
    """
    print("客户端已连接")
    loop = asyncio.get_running_loop()
    exhausted = object()  # sentinel: lets next() signal the end without StopIteration
    try:
        async for message in websocket:
            try:
                data = json.loads(message)
                texts = data.get("texts")
                title = data.get("title")
                top = data.get("top")
                print(f"收到消息: {texts}")

                # Retrieval and generation are blocking calls, so they run in
                # the default executor and the event loop stays free for
                # other clients and keep-alive pings.
                response = await loop.run_in_executor(None, main, texts, title, top)
                while True:
                    piece = await loop.run_in_executor(None, next, response, exhausted)
                    if piece is exhausted:
                        break
                    await websocket.send(piece)
                    await asyncio.sleep(0.001)  # small pause to avoid flooding the client

                await websocket.send("[DONE]")
                print("消息发送完成")

            except json.JSONDecodeError:
                await websocket.send('{"error": "Invalid JSON format"}')
            except websockets.exceptions.ConnectionClosed:
                # The peer is gone: sending an error frame would only fail again.
                raise
            except Exception as e:
                print(f"处理消息时发生错误: {e}")
                await websocket.send('{"error": "Internal server error"}')

    except websockets.exceptions.ConnectionClosed:
        print("客户端断开连接")
    except Exception as e:
        print(f"WebSocket处理异常: {e}")
    finally:
        print("连接处理结束")


async def main_api(wss_bool=False, host="0.0.0.0", port=27001):
    """Start the WebSocket server and block until it is closed.

    :param wss_bool: ``True`` to serve over TLS (WSS). If the certificate or
        key file is missing or cannot be loaded, the server falls back to
        plain WS and says so.
    :param host: interface to bind.
    :param port: TCP port to listen on.
    """
    try:
        ssl_context = None
        ssl_cert = "yitongtang66.com_ca_chains.crt"
        ssl_key = "yitongtang66.com.key"

        if not wss_bool:
            print("WSS未启用,将使用WS协议")
        elif not (os.path.isfile(ssl_cert) and os.path.isfile(ssl_key)):
            print("警告:SSL证书文件未找到,将使用WS协议")
        else:
            try:
                # A PROTOCOL_TLS_SERVER context already defaults to no
                # hostname check and CERT_NONE, so nothing else is set here.
                ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
                ssl_context.load_cert_chain(ssl_cert, ssl_key)
                print("SSL证书已加载,使用WSS协议")
            except OSError as e:  # ssl.SSLError is a subclass of OSError
                print(f"SSL证书加载失败: {e}")
                print("将使用WS协议")
                ssl_context = None

        server = await websockets.serve(
            handle_websocket,
            host,
            port,
            ssl=ssl_context,
            ping_interval=30,
            ping_timeout=30,
            close_timeout=30
        )

        if ssl_context:
            print(f"WSS服务器已启动: wss://{host}:{port}")
        else:
            print(f"WS服务器已启动: ws://{host}:{port}")

        # Keep the server running until it is closed.
        await server.wait_closed()

    except Exception as e:
        print(f"服务器启动失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)

    try:
        asyncio.run(main_api())
    except KeyboardInterrupt:
        print("服务器被用户中断")
    except Exception as e:
        print(f"服务器运行错误: {e}")