Compare commits

...

13 Commits

Author SHA1 Message Date
majiahui@haimaqingfan.com 634fd049d5 增加测试scokt请求,并更改流程,线上跑通 2 weeks ago
majiahui@haimaqingfan.com c07cc61c8b 增加测试scokt请求,并更改流程,线上跑通 2 weeks ago
majiahui@haimaqingfan.com c751f371a7 增加测试scokt请求,并更改流程,线上跑通 2 weeks ago
majiahui@haimaqingfan.com dab07a99e4 增加测试scokt请求,并更改流程,线上跑通 1 month ago
majiahui@haimaqingfan.com cfb587a02e 增加测试scokt请求,并更改流程,线上跑通 1 month ago
majiahui@haimaqingfan.com 2f48a11ac8 增加测试scokt请求,并更改流程 1 month ago
majiahui@haimaqingfan.com 8d7708f7b0 增加测试scokt请求,并更改流程 2 months ago
majiahui@haimaqingfan.com 2c91b46a66 增加测试scokt请求,并更改流程 2 months ago
majiahui@haimaqingfan.com a83c265f4e 增加测试scokt请求 6 months ago
majiahui@haimaqingfan.com 9c1cc4c768 增加scokt请求 6 months ago
majiahui@haimaqingfan.com 90af48046f 向量融合到csv中,并支持增删改查操作 6 months ago
majiahui@haimaqingfan.com 1e93757254 修复批量上传出错问题,并发承载力更强 6 months ago
majiahui@haimaqingfan.com 823923c927 单条数据上传 6 months ago
  1. 19
      README.md
  2. 85
      ceshi_scokt.py
  3. 456
      main.py
  4. 357
      main_scokt.py

19
README.md

@ -0,0 +1,19 @@
### 知识库增删改查
```
python main.py
```
### 知识库流式问答socket
```
python main_scokt.py
```
### 仿deepseek流式问答 socket (非 rag相关内容)
```
python main_scoket_deepspeek.py
```
### 具体接口可以参考接口文档11
接口文档.docx
https://console-docs.apipost.cn/preview/55b7d541588142d1/f4645422856c695a
https://console-docs.apipost.cn/preview/f03a79d844523711/2f4079d715d28b32

85
ceshi_scokt.py

@ -0,0 +1,85 @@
from openai import OpenAI
openai_api_key = "token-abc123"
openai_api_base = "http://127.0.0.1:12011/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
def model_generate_stream(prompt):
messages = [
{"role": "user", "content": prompt}
]
stream = client.chat.completions.create(model=model,
messages=messages,
stream=True)
printed_reasoning_content = False
printed_content = False
for chunk in stream:
reasoning_content = None
content = None
# Check the content is reasoning_content or content
if hasattr(chunk.choices[0].delta, "reasoning_content"):
reasoning_content = chunk.choices[0].delta.reasoning_content
elif hasattr(chunk.choices[0].delta, "content"):
content = chunk.choices[0].delta.content
if reasoning_content is not None:
if not printed_reasoning_content:
printed_reasoning_content = True
print("reasoning_content:", end="", flush=True)
print(reasoning_content, end="", flush=True)
elif content is not None:
if not printed_content:
printed_content = True
print("\ncontent:", end="", flush=True)
# Extract and print the content
# print(content, end="", flush=True)
print(content)
yield content
# if __name__ == '__main__':
# for i in model_generate_stream("你好"):
# print(i)
import asyncio
import websockets
import json
async def handle_websocket(websocket):
print("客户端已连接")
try:
while True:
message = await websocket.recv()
print("收到消息:", message)
data = json.loads(message)
texts = data.get("texts")
title = data.get("title")
top = data.get("top")
response = model_generate_stream(texts)
# response = message + "111"
for char in response:
await websocket.send(char)
# await asyncio.sleep(0.3)
await websocket.send("[DONE]")
except websockets.exceptions.ConnectionClosed:
print("客户端断开连接")
async def main():
async with websockets.serve(handle_websocket, "0.0.0.0", 5500):
print("WebSocket 服务器已启动,监听端口 5500")
await asyncio.Future() # 永久运行
if __name__ == "__main__":
asyncio.run(main()) # 正确启动事件循环

456
main.py

@ -2,21 +2,43 @@
# 按 Shift+F10 执行或将其替换为您的代码。
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import requests
import time
from flask import Flask, jsonify
from flask import request
from flask import Flask, jsonify, Response, request
from openai import OpenAI
from flask_cors import CORS
import pandas as pd
import concurrent.futures
import json
import torch
import uuid
# flask配置
app = Flask(__name__)
CORS(app)
app.config["JSON_AS_ASCII"] = False
model = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')
openai_api_key = "token-abc123"
openai_api_base = "http://127.0.0.1:12011/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
# 模型配置
models = client.models.list()
model = models.data[0].id
# model = "1"
model_encode = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')
# 提示配置
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
{}
根据这些症状我通过查找资料{}
@ -26,7 +48,7 @@ propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
{}'''
def dialog_line_parse(url, text):
def dialog_line_parse(text):
"""
将数据输入模型进行分析并输出结果
:param url: 模型url
@ -34,8 +56,9 @@ def dialog_line_parse(url, text):
:return: 模型返回结果
"""
url_predict = "http://118.178.228.101:12004/predict"
response = requests.post(
url,
url_predict,
json=text,
timeout=100000
)
@ -49,46 +72,261 @@ def dialog_line_parse(url, text):
# )
print("{}】 Failed to get a proper response from remote "
"server. Status Code: {}. Response: {}"
"".format(url, response.status_code, response.text))
"".format(url_predict, response.status_code, response.text))
return {}
# ['choices'][0]['message']['content']
#
# text = text['messages'][0]['content']
# return_text = {
# 'code': 200,
# 'id': "1",
# 'object': 0,
# 'created': 0,
# 'model': 0,
# 'choices': [
# {
# 'index': 0,
# 'message': {
# 'role': 'assistant',
# 'content': text
# },
# 'logprobs': None,
# 'finish_reason': 'stop'
# }
# ],
# 'usage': 0,
# 'system_fingerprint': 0
# }
# return return_text
def shengcehng_array(data):
embs = model.encode(data, normalize_embeddings=True)
'''
模型生成向量
:param data:
:return:
'''
embs = model_encode.encode(data, normalize_embeddings=True)
return embs
def Building_vector_database(type, name, df):
data_ndarray = np.empty((0, 1024))
for sen in df:
data_ndarray = np.concatenate((data_ndarray, shengcehng_array([sen[0]])))
print("data_ndarray.shape", data_ndarray.shape)
print("data_ndarray.shape", data_ndarray.shape)
np.save(f'data_np/{name}.npy', data_ndarray)
def Building_vector_database(title, df):
'''
次函数暂时弃用
:param title:
:param df:
:return:
'''
# 加载需要处理的数据(有效且未向量化)
to_process = df[(df["有效"] == True) & (df["已向量化"] == False)]
if len(to_process) == 0:
print("无新增数据需要向量化")
return
# 生成向量数组
new_vectors = shengcehng_array(to_process["总结"].tolist()) # 假设这是你的向量生成函数
# 加载现有向量库和索引
vector_path = f"data_np/{title}.npy"
index_path = f"data_np/{title}_index.json"
vectors = np.load(vector_path) if os.path.exists(vector_path) else np.empty((0, 1024))
index_data = {}
if os.path.exists(index_path):
with open(index_path, "r") as f:
index_data = json.load(f)
# 更新索引和向量库
start_idx = len(vectors)
vectors = np.vstack([vectors, new_vectors])
for i, (_, row) in enumerate(to_process.iterrows()):
index_data[row["ID"]] = {
"row": start_idx + i,
"valid": True
}
# 保存数据
np.save(vector_path, vectors)
with open(index_path, "w") as f:
json.dump(index_data, f)
# 标记已向量化
df.loc[to_process.index, "已向量化"] = True
df.to_csv(f"data_file_res/{title}.csv", sep="\t", index=False)
def delete_data(title, new_id):
'''
假删除只是标记有效无效
:param title:
:param new_id:
:return:
'''
new_id = str(new_id)
# 更新CSV标记
csv_path = f"data_file_res/{title}.csv"
df = pd.read_csv(csv_path, sep="\t")
# df.loc[df["ID"] == new_id, "有效"] = False
df.loc[df['ID'] == new_id, "有效"] = False
df.to_csv(csv_path, sep="\t", index=False)
return "删除完成"
def check_file_exists(file_path):
"""
检查文件是否存在
参数:
file_path (str): 要检查的文件路径
返回:
bool: 文件存在返回True否则返回False
"""
return os.path.isfile(file_path)
def ulit_request_file(sentence, title, zongjie):
'''
上传文件生成固定内容"ID", "正文", "总结", "有效", "向量"
:param sentence:
:param title:
:param zongjie:
:return:
'''
file_name_res_save = f"data_file_res/{title}.csv"
# 初始化或读取CSV文件,如果存在文件,读取文件,并添加行,
# 如果不存在文件,新建DataFrame
if os.path.exists(file_name_res_save):
df = pd.read_csv(file_name_res_save, sep="\t")
# 检查是否已存在相同正文
if sentence in df["正文"].values:
if zongjie == None:
return "正文已存在,跳过处理"
else:
result = df[df['正文'] == sentence]
id_ = result['ID'].values[0]
print(id_)
return ulit_request_file_zongjie(id_, sentence, zongjie, title)
else:
df = pd.DataFrame(columns=["ID", "正文", "总结", "有效", "向量"])
# 添加新数据(生成唯一ID)
if zongjie == None:
id_ = str(uuid.uuid1())
new_row = {
"ID": id_,
"正文": sentence,
"总结": None,
"有效": True,
"向量": None
}
df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
# 需要根据不同的项目修改提示,目的是精简内容,为了方便匹配
data_dan = {
"model": "gpt-4-turbo",
"messages": [{
"role": "user",
"content": f"{sentence}\n以上这条中可能包含了一些病情或者症状,请帮我归纳这条中所对应的病情或者症状是哪些,总结出来,不需要很长,简单归纳即可,直接输出症状或者病情,可以包含一些形容词来辅助描述,不需要有辅助词汇"
}],
"top_p": 0.9,
"temperature": 0.3
}
results = dialog_line_parse(data_dan)
summary = results['choices'][0]['message']['content']
# 这是你的向量生成函数,来生成总结的词汇的向量
new_vectors = shengcehng_array([summary])
df.loc[df['ID'] == id_, '总结'] = summary
df.loc[df['ID'] == id_, '向量'] = str(new_vectors[0].tolist())
else:
id_ = str(uuid.uuid1())
new_row = {
"ID": id_ ,
"正文": sentence,
"总结": zongjie,
"有效": True,
"向量": None
}
df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
new_vectors = shengcehng_array([zongjie]) # 假设这是你的向量生成函数
df.loc[df['ID'] == id_, '总结'] = zongjie
df.loc[df['ID'] == id_, '向量'] = str(new_vectors[0].tolist())
# 保存更新后的CSV
df.to_csv(file_name_res_save, sep="\t", index=False)
return "上传完成"
def ulit_request_file_zongjie(new_id, sentence, zongjie, title):
new_id = str(new_id)
print(new_id)
print(type(new_id))
file_name_res_save = f"data_file_res/{title}.csv"
# 初始化或读取CSV文件
df = pd.read_csv(file_name_res_save, sep="\t")
df.loc[df['ID'] == new_id, '正文'] = sentence
if zongjie == None:
pass
else:
df.loc[df['ID'] == new_id, '总结'] = zongjie
new_vectors = shengcehng_array([zongjie]) # 假设这是你的向量生成函数
df.loc[df['ID'] == new_id, '向量'] = str(new_vectors[0].tolist())
# 保存更新后的CSV
df.to_csv(file_name_res_save, sep="\t", index=False)
return "修改完成"
def ulit_request_file(file, title):
file_name = file.filename
file_name_save = "data_file/{}.csv".format(title)
file.save(file_name_save)
def ulit_request_file_check(title):
file_name_res_save = f"data_file_res/{title}.csv"
# try:
# with open(file_name_save, encoding="gbk") as f:
# content = f.read()
# except:
# with open(file_name_save, encoding="utf-8") as f:
# content = f.read()
# elif file_name.split(".")[-1] == "docx":
# content = docx2txt.process(file_name_save)
# 初始化或读取CSV文件
# 初始化或读取CSV文件
if os.path.exists(file_name_res_save):
df = pd.read_csv(file_name_res_save, sep="\t").values.tolist()
data_new = []
for i in df:
if i[3] == True:
data_new.append([i[0], i[1], i[2]])
return data_new
else:
return "无可展示文件"
# content_list = [i for i in content.split("\n")]
df = pd.read_csv(file_name_save, sep="\t", encoding="utf-8").values.tolist()
return df
def ulit_request_file_check_dan(new_id, title):
new_id = str(new_id)
file_name_res_save = f"data_file_res/{title}.csv"
# 初始化或读取CSV文件
def main(question, db_type, top):
# 初始化或读取CSV文件
if os.path.exists(file_name_res_save):
df = pd.read_csv(file_name_res_save, sep="\t")
zhengwen = df.loc[df['ID'] == new_id, '正文'].values
zongjie = df.loc[df['ID'] == new_id, '总结'].values
# 输出结果
if len(zhengwen) > 0:
if df.loc[df['ID'] == new_id, '有效'].values == True:
return [new_id, zhengwen[0], zongjie[0]]
else:
return "未找到对应的ID"
else:
return "未找到对应的ID"
else:
return "无可展示文件"
def main(question, title, top):
db_dict = {
"1": "yetianshi"
}
@ -114,26 +352,38 @@ def main(question, db_type, top):
根据提问匹配上下文
'''
d = 1024
db_type_list = db_type.split(",")
db_type_list = title.split(",")
paper_list_str = ""
for i in db_type_list:
for title_dan in db_type_list:
embs = shengcehng_array([question])
index = faiss.IndexFlatIP(d) # buid the index
data_np = np.load(f"data_np/{i}.npy")
# data_str = open(f"data_file/{i}.txt").read().split("\n")
data_str = pd.read_csv(f"data_file/{i}.csv", sep="\t", encoding="utf-8").values.tolist()
index.add(data_np)
# 查找向量
# vector_path = f"data_np/{title_dan}.npy"
# vectors = np.load(vector_path)
data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist()
data_str_valid = []
for i in data_str:
if i[3] == True:
data_str_valid.append(i)
data_str_vectors_list = []
for i in data_str_valid:
data_str_vectors_list.append(eval(i[-1]))
vectors = np.array(data_str_vectors_list)
index.add(vectors)
D, I = index.search(embs, int(top))
print(I)
reference_list = []
for i,j in zip(I[0], D[0]):
reference_list.append([data_str[i], j])
reference_list.append([data_str_valid[i], j])
for i,j in enumerate(reference_list):
paper_list_str += "{}\n{},此篇文章的转发数为{},评论数为{},点赞数为{}\n,此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][0], j[0][1], j[0][2], j[0][3], j[1])
paper_list_str += "{}\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1])
'''
构造prompt
'''
@ -147,61 +397,86 @@ def main(question, db_type, top):
'''
生成回答
'''
url_predict = "http://192.168.31.74:26000/predict"
url_search = "http://192.168.31.74:26000/search"
return model_generate_stream(propmt_connect_input)
def model_generate_stream(prompt):
messages = [
{"role": "user", "content": prompt}
]
stream = client.chat.completions.create(model=model,
messages=messages,
stream=True)
printed_reasoning_content = False
printed_content = False
for chunk in stream:
reasoning_content = None
content = None
# Check the content is reasoning_content or content
if hasattr(chunk.choices[0].delta, "reasoning_content"):
reasoning_content = chunk.choices[0].delta.reasoning_content
elif hasattr(chunk.choices[0].delta, "content"):
content = chunk.choices[0].delta.content
if reasoning_content is not None:
if not printed_reasoning_content:
printed_reasoning_content = True
print("reasoning_content:", end="", flush=True)
print(reasoning_content, end="", flush=True)
elif content is not None:
if not printed_content:
printed_content = True
print("\ncontent:", end="", flush=True)
# Extract and print the content
# print(content, end="", flush=True)
print(content, end="")
yield content
# data = {
# "content": propmt_connect_input
# }
data = {
"content": propmt_connect_input,
"model": "qwq-32",
"top_p": 0.9,
"temperature": 0.6
}
res = dialog_line_parse(url_predict, data)
id_ = res["texts"]["id"]
data = {
"id": id_
@app.route("/upload_file_check", methods=["POST"])
def upload_file_check():
print(request.remote_addr)
sentence = request.form.get('sentence')
title = request.form.get("title")
new_id = request.form.get("id")
zongjie = request.form.get("zongjie")
state = request.form.get("state")
'''
{
"1": "csv",
"2": "xlsx",
"3": "txt",
"4": "pdf"
}
'''
# 增
state_res = ""
if state == "1":
state_res = ulit_request_file(sentence, title, zongjie)
while True:
res = dialog_line_parse(url_search, data)
if res["code"] == 200:
break
else:
time.sleep(1)
spilt_str = "</think>"
think, response = str(res["text"]).split(spilt_str)
return think, response
# 删
elif state == "2":
state_res = delete_data(title, new_id)
@app.route("/upload_file", methods=["POST"])
def upload_file():
print(request.remote_addr)
file = request.files.get('file')
title = request.form.get("title")
df = ulit_request_file(file, title)
Building_vector_database("1", title, df)
return_json = {
"code": 200,
"info": "上传完成"
}
return jsonify(return_json) # 返回结果
# 改
elif state == "3":
state_res = ulit_request_file_zongjie(new_id, sentence, zongjie,title)
# 查
elif state == "4":
state_res = ulit_request_file_check(title)
# 通过uuid查单条数据
elif state == "5":
ulit_request_file_check_dan(new_id, title)
state_res = ""
@app.route("/upload_file_check", methods=["POST"])
def upload_file_check():
print(request.remote_addr)
file = request.files.get('file')
title = request.form.get("title")
df = ulit_request_file(file, title)
Building_vector_database("1", title, df)
return_json = {
"code": 200,
"info": "上传完成"
"info": state_res
}
return jsonify(return_json) # 返回结果
@ -210,15 +485,10 @@ def upload_file_check():
def search():
print(request.remote_addr)
texts = request.json["texts"]
text_type = request.json["text_type"]
title = request.json["title"]
top = request.json["top"]
think, response = main(texts, text_type, top)
return_json = {
"code": 200,
"think": think,
"response": response
}
return jsonify(return_json) # 返回结果
response = main(texts, title, top)
return Response(response, mimetype='text/plain; charset=utf-8') # 返回结果
if __name__ == "__main__":

357
main_scokt.py

@ -0,0 +1,357 @@
# 这是一个示例 Python 脚本。
# 按 Shift+F10 执行或将其替换为您的代码。
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import faiss
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import requests
import time
from flask import Flask, jsonify, Response, request
from flask_cors import CORS
import pandas as pd
import redis
from openai import OpenAI
import asyncio
import websockets
import json
import ssl
import pathlib
app = Flask(__name__)
CORS(app)
app.config["JSON_AS_ASCII"] = False
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=1, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
db_key_querying = 'querying'
batch_size = 32
openai_api_key = "token-abc123"
openai_api_base = "http://127.0.0.1:12011/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# model = "1"
model_encode = SentenceTransformer('/home/zhangbaoxun/project/models-llm/bge-large-zh-v1.5')
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
{}
根据这些症状我通过查找资料{}
请根据上面的这些资料和方子并根据每篇文章的转发数确定文章的重要程度转发数越高的文章最终答案的参考度越高反之越低根据患者的症状和上面的文章的资料的重要程度以及文章和症状的匹配程度帮我开出正确的药方和治疗方案'''
propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
{}'''
def dialog_line_parse(text):
"""
将数据输入模型进行分析并输出结果
:param url: 模型url
:param text: 进入模型的数据
:return: 模型返回结果
"""
url_predict = "http://118.178.228.101:12004/predict"
response = requests.post(
url_predict,
json=text,
timeout=100000
)
if response.status_code == 200:
return response.json()
else:
# logger.error(
# "【{}】 Failed to get a proper response from remote "
# "server. Status Code: {}. Response: {}"
# "".format(url, response.status_code, response.text)
# )
print("{}】 Failed to get a proper response from remote "
"server. Status Code: {}. Response: {}"
"".format(url_predict, response.status_code, response.text))
return {}
# ['choices'][0]['message']['content']
#
# text = text['messages'][0]['content']
# return_text = {
# 'code': 200,
# 'id': "1",
# 'object': 0,
# 'created': 0,
# 'model': 0,
# 'choices': [
# {
# 'index': 0,
# 'message': {
# 'role': 'assistant',
# 'content': text
# },
# 'logprobs': None,
# 'finish_reason': 'stop'
# }
# ],
# 'usage': 0,
# 'system_fingerprint': 0
# }
# return return_text
def shengcehng_array(data):
embs = model_encode.encode(data, normalize_embeddings=True)
return embs
def main(question, title, top):
'''
主函数用来匹配句子放到prompt中生成回答
:param question:
:param title:
:param top:
:return:
'''
db_dict = {
"1": "yetianshi"
}
'''
定义文件路径
'''
'''
加载文件
'''
'''
文本分割
'''
'''
构建向量数据库
1. 正常匹配
2. 把文本使用大模型生成一个问题之后再进行匹配
'''
'''
根据提问匹配上下文
'''
d = 1024
db_type_list = title.split(",")
paper_list_str = ""
for title_dan in db_type_list:
embs = shengcehng_array([question])
index = faiss.IndexFlatIP(d) # buid the index
# 查找向量
# vector_path = f"data_np/{title_dan}.npy"
# vectors = np.load(vector_path)
# 读取向量文件 csv文件结构:
# ID
# 正文
# 总结
# 有效
# 向量
data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist()
data_str_valid = []
for i in data_str:
# i[3] == True 说明数据没有被删除,如果是false说明被删除
if i[3] == True:
data_str_valid.append(i)
# 把有效数据的向量汇总出来
data_str_vectors_list = []
for i in data_str_valid:
data_str_vectors_list.append(eval(i[-1]))
# 拼接成向量矩阵
vectors = np.array(data_str_vectors_list)
index.add(vectors)
# 使用faiss找到最相似向量
D, I = index.search(embs, int(top))
print(I)
reference_list = []
for i,j in zip(I[0], D[0]):
# 添加 csv对应的数据 data_str_valid[i]表示 csv中一行的所有数据 ID 正文 总结 有效 向量 以及 j表示相关度是多少
reference_list.append([data_str_valid[i], j])
for i,j in enumerate(reference_list):
paper_list_str += "{}\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1])
'''
构造prompt
'''
print("paper_list_str", paper_list_str)
propmt_connect_ziliao_input = []
for i in db_type_list:
propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(i, paper_list_str))
# 构造输入问题,把上面的都展示出来
propmt_connect_ziliao_input_str = "".join(propmt_connect_ziliao_input)
propmt_connect_input = propmt_connect.format(question, propmt_connect_ziliao_input_str)
'''
生成回答这个model_generate_stream可以根据需要指定模型
'''
return model_generate_stream(propmt_connect_input)
def model_generate_stream(prompt):
messages = [
{"role": "user", "content": prompt}
]
stream = client.chat.completions.create(model=model,
messages=messages,
stream=True)
printed_reasoning_content = False
printed_content = False
for chunk in stream:
reasoning_content = None
content = None
# Check the content is reasoning_content or content
if hasattr(chunk.choices[0].delta, "reasoning_content"):
reasoning_content = chunk.choices[0].delta.reasoning_content
elif hasattr(chunk.choices[0].delta, "content"):
content = chunk.choices[0].delta.content
if reasoning_content is not None:
if not printed_reasoning_content:
printed_reasoning_content = True
print("reasoning_content:", end="", flush=True)
print(reasoning_content, end="", flush=True)
elif content is not None:
if not printed_content:
printed_content = True
print("\ncontent:", end="", flush=True)
# Extract and print the content
# print(content, end="", flush=True)
print(content)
yield content
async def handle_websocket(websocket):
print("客户端已连接")
try:
async for message in websocket:
try:
data = json.loads(message)
texts = data.get("texts")
title = data.get("title")
top = data.get("top")
print(f"收到消息: {texts}")
# 获取响应
response = main(texts, title, top)
# 发送响应
for char in response:
await websocket.send(char)
await asyncio.sleep(0.001) # 小延迟避免发送过快
# 发送完成标记
await websocket.send("[DONE]")
print("消息发送完成")
except json.JSONDecodeError:
await websocket.send('{"error": "Invalid JSON format"}')
except Exception as e:
print(f"处理消息时发生错误: {e}")
await websocket.send('{"error": "Internal server error"}')
except websockets.exceptions.ConnectionClosed:
print("客户端断开连接")
except Exception as e:
print(f"WebSocket处理异常: {e}")
finally:
print("连接处理结束")
async def main_api():
try:
# 是否加载
ssl_context = None
# wss服务开关 True是打开wss服务
wss_bool = False
# 检查证书文件是否存在
ssl_cert = "yitongtang66.com_ca_chains.crt"
ssl_key = "yitongtang66.com.key"
# ssl_cert = "yizherenxin.cn.crt"
# ssl_key = "yizherenxin.cn.key"
# ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# ssl_context.load_cert_chain(ssl_cert, ssl_key)
# ssl_context.check_hostname = False # 必须禁用主机名验证
# ssl_context.verify_mode = ssl.CERT_NONE # 不验证证书
if wss_bool == True:
try:
# 创建SSL上下文
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# 加载证书链
ssl_context.load_cert_chain(ssl_cert, ssl_key)
# 禁用主机名验证(对于自签名证书)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
print("SSL证书已加载,使用WSS协议")
except Exception as e:
print(f"SSL证书加载失败: {e}")
print("将使用WS协议")
ssl_context = None
else:
print("警告:SSL证书文件未找到,将使用WS协议")
ssl_context = None
# 创建服务器
server = await websockets.serve(
handle_websocket,
"0.0.0.0",
27001,
ssl=ssl_context,
ping_interval=30, # 添加ping间隔
ping_timeout=30, # 添加ping超时
close_timeout=30 # 添加关闭超时
)
# 启动27001端口
if ssl_context:
print("WSS服务器已启动: wss://0.0.0.0:27001")
else:
print("WS服务器已启动: ws://0.0.0.0:27001")
# 保持服务器运行
await server.wait_closed()
except Exception as e:
print(f"服务器启动失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
# 设置更详细的事件循环调试
import logging
logging.basicConfig(level=logging.INFO)
# 启动服务
try:
asyncio.run(main_api())
except KeyboardInterrupt:
print("服务器被用户中断")
except Exception as e:
print(f"服务器运行错误: {e}")
Loading…
Cancel
Save