Compare commits
10 Commits
Author | SHA1 | Date |
---|---|---|
![]() |
dab07a99e4 | 1 week ago |
![]() |
cfb587a02e | 1 week ago |
![]() |
2f48a11ac8 | 3 weeks ago |
![]() |
8d7708f7b0 | 4 weeks ago |
![]() |
2c91b46a66 | 1 month ago |
![]() |
a83c265f4e | 5 months ago |
![]() |
9c1cc4c768 | 5 months ago |
![]() |
90af48046f | 5 months ago |
![]() |
1e93757254 | 5 months ago |
![]() |
823923c927 | 5 months ago |
3 changed files with 720 additions and 91 deletions
@ -0,0 +1,85 @@ |
|||||
|
from openai import OpenAI |
||||
|
|
||||
|
openai_api_key = "token-abc123" |
||||
|
openai_api_base = "http://127.0.0.1:12011/v1" |
||||
|
|
||||
|
client = OpenAI( |
||||
|
api_key=openai_api_key, |
||||
|
base_url=openai_api_base, |
||||
|
) |
||||
|
|
||||
|
models = client.models.list() |
||||
|
model = models.data[0].id |
||||
|
|
||||
|
|
||||
|
def model_generate_stream(prompt): |
||||
|
messages = [ |
||||
|
{"role": "user", "content": prompt} |
||||
|
] |
||||
|
|
||||
|
stream = client.chat.completions.create(model=model, |
||||
|
messages=messages, |
||||
|
stream=True) |
||||
|
printed_reasoning_content = False |
||||
|
printed_content = False |
||||
|
|
||||
|
for chunk in stream: |
||||
|
reasoning_content = None |
||||
|
content = None |
||||
|
# Check the content is reasoning_content or content |
||||
|
if hasattr(chunk.choices[0].delta, "reasoning_content"): |
||||
|
reasoning_content = chunk.choices[0].delta.reasoning_content |
||||
|
elif hasattr(chunk.choices[0].delta, "content"): |
||||
|
content = chunk.choices[0].delta.content |
||||
|
|
||||
|
if reasoning_content is not None: |
||||
|
if not printed_reasoning_content: |
||||
|
printed_reasoning_content = True |
||||
|
print("reasoning_content:", end="", flush=True) |
||||
|
print(reasoning_content, end="", flush=True) |
||||
|
elif content is not None: |
||||
|
if not printed_content: |
||||
|
printed_content = True |
||||
|
print("\ncontent:", end="", flush=True) |
||||
|
# Extract and print the content |
||||
|
# print(content, end="", flush=True) |
||||
|
print(content) |
||||
|
yield content |
||||
|
# if __name__ == '__main__': |
||||
|
# for i in model_generate_stream("你好"): |
||||
|
# print(i) |
||||
|
|
||||
|
|
||||
|
import asyncio |
||||
|
import websockets |
||||
|
import json |
||||
|
|
||||
|
|
||||
|
async def handle_websocket(websocket): |
||||
|
print("客户端已连接") |
||||
|
try: |
||||
|
while True: |
||||
|
message = await websocket.recv() |
||||
|
print("收到消息:", message) |
||||
|
|
||||
|
data = json.loads(message) |
||||
|
texts = data.get("texts") |
||||
|
title = data.get("title") |
||||
|
top = data.get("top") |
||||
|
|
||||
|
response = model_generate_stream(texts) |
||||
|
# response = message + "111" |
||||
|
for char in response: |
||||
|
await websocket.send(char) |
||||
|
# await asyncio.sleep(0.3) |
||||
|
await websocket.send("[DONE]") |
||||
|
except websockets.exceptions.ConnectionClosed: |
||||
|
print("客户端断开连接") |
||||
|
|
||||
|
async def main(): |
||||
|
async with websockets.serve(handle_websocket, "0.0.0.0", 5500): |
||||
|
print("WebSocket 服务器已启动,监听端口 5500") |
||||
|
await asyncio.Future() # 永久运行 |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
asyncio.run(main()) # 正确启动事件循环 |
@ -0,0 +1,331 @@ |
|||||
|
# 这是一个示例 Python 脚本。 |
||||
|
|
||||
|
# 按 Shift+F10 执行或将其替换为您的代码。 |
||||
|
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 |
||||
|
import os |
||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
||||
|
import faiss |
||||
|
import numpy as np |
||||
|
from tqdm import tqdm |
||||
|
from sentence_transformers import SentenceTransformer |
||||
|
import requests |
||||
|
import time |
||||
|
from flask import Flask, jsonify, Response, request |
||||
|
from flask_cors import CORS |
||||
|
import pandas as pd |
||||
|
import redis |
||||
|
from openai import OpenAI |
||||
|
import asyncio |
||||
|
import websockets |
||||
|
import json |
||||
|
import ssl |
||||
|
import pathlib |
||||
|
|
||||
|
app = Flask(__name__) |
||||
|
CORS(app) |
||||
|
app.config["JSON_AS_ASCII"] = False |
||||
|
|
||||
|
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=1, password="zhicheng123*") |
||||
|
redis_ = redis.Redis(connection_pool=pool, decode_responses=True) |
||||
|
|
||||
|
db_key_query = 'query' |
||||
|
db_key_querying = 'querying' |
||||
|
batch_size = 32 |
||||
|
|
||||
|
openai_api_key = "token-abc123" |
||||
|
openai_api_base = "http://127.0.0.1:12011/v1" |
||||
|
|
||||
|
client = OpenAI( |
||||
|
api_key=openai_api_key, |
||||
|
base_url=openai_api_base, |
||||
|
) |
||||
|
|
||||
|
models = client.models.list() |
||||
|
model = models.data[0].id |
||||
|
# model = "1" |
||||
|
model_encode = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5') |
||||
|
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下: |
||||
|
{} |
||||
|
根据这些症状,我通过查找资料,{} |
||||
|
请根据上面的这些资料和方子,并根据每篇文章的转发数确定文章的重要程度,转发数越高的文章,最终答案的参考度越高,反之越低。根据患者的症状和上面的文章的资料的重要程度以及文章和症状的匹配程度,帮我开出正确的药方和治疗方案''' |
||||
|
|
||||
|
propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容: |
||||
|
{}''' |
||||
|
|
||||
|
|
||||
|
def dialog_line_parse(text): |
||||
|
""" |
||||
|
将数据输入模型进行分析并输出结果 |
||||
|
:param url: 模型url |
||||
|
:param text: 进入模型的数据 |
||||
|
:return: 模型返回结果 |
||||
|
""" |
||||
|
|
||||
|
url_predict = "http://118.178.228.101:12004/predict" |
||||
|
response = requests.post( |
||||
|
url_predict, |
||||
|
json=text, |
||||
|
timeout=100000 |
||||
|
) |
||||
|
if response.status_code == 200: |
||||
|
return response.json() |
||||
|
else: |
||||
|
# logger.error( |
||||
|
# "【{}】 Failed to get a proper response from remote " |
||||
|
# "server. Status Code: {}. Response: {}" |
||||
|
# "".format(url, response.status_code, response.text) |
||||
|
# ) |
||||
|
print("【{}】 Failed to get a proper response from remote " |
||||
|
"server. Status Code: {}. Response: {}" |
||||
|
"".format(url_predict, response.status_code, response.text)) |
||||
|
return {} |
||||
|
|
||||
|
# ['choices'][0]['message']['content'] |
||||
|
# |
||||
|
# text = text['messages'][0]['content'] |
||||
|
# return_text = { |
||||
|
# 'code': 200, |
||||
|
# 'id': "1", |
||||
|
# 'object': 0, |
||||
|
# 'created': 0, |
||||
|
# 'model': 0, |
||||
|
# 'choices': [ |
||||
|
# { |
||||
|
# 'index': 0, |
||||
|
# 'message': { |
||||
|
# 'role': 'assistant', |
||||
|
# 'content': text |
||||
|
# }, |
||||
|
# 'logprobs': None, |
||||
|
# 'finish_reason': 'stop' |
||||
|
# } |
||||
|
# ], |
||||
|
# 'usage': 0, |
||||
|
# 'system_fingerprint': 0 |
||||
|
# } |
||||
|
# return return_text |
||||
|
|
||||
|
|
||||
|
def shengcehng_array(data): |
||||
|
embs = model_encode.encode(data, normalize_embeddings=True) |
||||
|
return embs |
||||
|
|
||||
|
|
||||
|
def main(question, title, top): |
||||
|
db_dict = { |
||||
|
"1": "yetianshi" |
||||
|
} |
||||
|
''' |
||||
|
定义文件路径 |
||||
|
''' |
||||
|
|
||||
|
''' |
||||
|
加载文件 |
||||
|
''' |
||||
|
|
||||
|
''' |
||||
|
文本分割 |
||||
|
''' |
||||
|
|
||||
|
''' |
||||
|
构建向量数据库 |
||||
|
1. 正常匹配 |
||||
|
2. 把文本使用大模型生成一个问题之后再进行匹配 |
||||
|
''' |
||||
|
|
||||
|
''' |
||||
|
根据提问匹配上下文 |
||||
|
''' |
||||
|
d = 1024 |
||||
|
db_type_list = title.split(",") |
||||
|
|
||||
|
paper_list_str = "" |
||||
|
for title_dan in db_type_list: |
||||
|
embs = shengcehng_array([question]) |
||||
|
index = faiss.IndexFlatIP(d) # buid the index |
||||
|
|
||||
|
# 查找向量 |
||||
|
# vector_path = f"data_np/{title_dan}.npy" |
||||
|
# vectors = np.load(vector_path) |
||||
|
|
||||
|
data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() |
||||
|
|
||||
|
data_str_valid = [] |
||||
|
for i in data_str: |
||||
|
if i[3] == True: |
||||
|
data_str_valid.append(i) |
||||
|
|
||||
|
data_str_vectors_list = [] |
||||
|
for i in data_str_valid: |
||||
|
data_str_vectors_list.append(eval(i[-1])) |
||||
|
vectors = np.array(data_str_vectors_list) |
||||
|
index.add(vectors) |
||||
|
D, I = index.search(embs, int(top)) |
||||
|
print(I) |
||||
|
|
||||
|
reference_list = [] |
||||
|
for i,j in zip(I[0], D[0]): |
||||
|
reference_list.append([data_str_valid[i], j]) |
||||
|
|
||||
|
for i,j in enumerate(reference_list): |
||||
|
paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1]) |
||||
|
''' |
||||
|
构造prompt |
||||
|
''' |
||||
|
print("paper_list_str", paper_list_str) |
||||
|
propmt_connect_ziliao_input = [] |
||||
|
for i in db_type_list: |
||||
|
propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(i, paper_list_str)) |
||||
|
|
||||
|
propmt_connect_ziliao_input_str = ",".join(propmt_connect_ziliao_input) |
||||
|
propmt_connect_input = propmt_connect.format(question, propmt_connect_ziliao_input_str) |
||||
|
''' |
||||
|
生成回答 |
||||
|
''' |
||||
|
return model_generate_stream(propmt_connect_input) |
||||
|
|
||||
|
|
||||
|
def model_generate_stream(prompt): |
||||
|
messages = [ |
||||
|
{"role": "user", "content": prompt} |
||||
|
] |
||||
|
|
||||
|
stream = client.chat.completions.create(model=model, |
||||
|
messages=messages, |
||||
|
stream=True) |
||||
|
printed_reasoning_content = False |
||||
|
printed_content = False |
||||
|
|
||||
|
for chunk in stream: |
||||
|
reasoning_content = None |
||||
|
content = None |
||||
|
# Check the content is reasoning_content or content |
||||
|
if hasattr(chunk.choices[0].delta, "reasoning_content"): |
||||
|
reasoning_content = chunk.choices[0].delta.reasoning_content |
||||
|
elif hasattr(chunk.choices[0].delta, "content"): |
||||
|
content = chunk.choices[0].delta.content |
||||
|
|
||||
|
if reasoning_content is not None: |
||||
|
if not printed_reasoning_content: |
||||
|
printed_reasoning_content = True |
||||
|
print("reasoning_content:", end="", flush=True) |
||||
|
print(reasoning_content, end="", flush=True) |
||||
|
elif content is not None: |
||||
|
if not printed_content: |
||||
|
printed_content = True |
||||
|
print("\ncontent:", end="", flush=True) |
||||
|
# Extract and print the content |
||||
|
# print(content, end="", flush=True) |
||||
|
print(content) |
||||
|
yield content |
||||
|
|
||||
|
|
||||
|
async def handle_websocket(websocket): |
||||
|
print("客户端已连接") |
||||
|
try: |
||||
|
async for message in websocket: |
||||
|
try: |
||||
|
data = json.loads(message) |
||||
|
texts = data.get("texts") |
||||
|
title = data.get("title") |
||||
|
top = data.get("top") |
||||
|
print(f"收到消息: {texts}") |
||||
|
|
||||
|
# 获取响应 |
||||
|
response = main(texts, title, top) |
||||
|
|
||||
|
# 发送响应 |
||||
|
for char in response: |
||||
|
await websocket.send(char) |
||||
|
await asyncio.sleep(0.001) # 小延迟避免发送过快 |
||||
|
|
||||
|
# 发送完成标记 |
||||
|
await websocket.send("[DONE]") |
||||
|
print("消息发送完成") |
||||
|
|
||||
|
except json.JSONDecodeError: |
||||
|
await websocket.send('{"error": "Invalid JSON format"}') |
||||
|
except Exception as e: |
||||
|
print(f"处理消息时发生错误: {e}") |
||||
|
await websocket.send('{"error": "Internal server error"}') |
||||
|
|
||||
|
except websockets.exceptions.ConnectionClosed: |
||||
|
print("客户端断开连接") |
||||
|
except Exception as e: |
||||
|
print(f"WebSocket处理异常: {e}") |
||||
|
finally: |
||||
|
print("连接处理结束") |
||||
|
|
||||
|
async def main_api(): |
||||
|
try: |
||||
|
ssl_context = None |
||||
|
|
||||
|
# 检查证书文件是否存在 |
||||
|
|
||||
|
ssl_cert = "yitongtang66.com.crt" |
||||
|
ssl_key = "yitongtang66.com.key" |
||||
|
|
||||
|
# ssl_cert = "yizherenxin.cn.crt" |
||||
|
# ssl_key = "yizherenxin.cn.key" |
||||
|
|
||||
|
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
||||
|
ssl_context.load_cert_chain(ssl_cert, ssl_key) |
||||
|
ssl_context.check_hostname = False # 必须禁用主机名验证 |
||||
|
ssl_context.verify_mode = ssl.CERT_NONE # 不验证证书 |
||||
|
|
||||
|
# if os.path.exists(ssl_cert) and os.path.exists(ssl_key): |
||||
|
# try: |
||||
|
# # 创建SSL上下文 |
||||
|
# ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) |
||||
|
# # 加载证书链 |
||||
|
# ssl_context.load_cert_chain(ssl_cert, ssl_key) |
||||
|
# # 禁用主机名验证(对于自签名证书) |
||||
|
# ssl_context.check_hostname = False |
||||
|
# ssl_context.verify_mode = ssl.CERT_NONE |
||||
|
# print("SSL证书已加载,使用WSS协议") |
||||
|
# except Exception as e: |
||||
|
# print(f"SSL证书加载失败: {e}") |
||||
|
# print("将使用WS协议") |
||||
|
# ssl_context = None |
||||
|
# else: |
||||
|
# print("警告:SSL证书文件未找到,将使用WS协议") |
||||
|
# ssl_context = None |
||||
|
|
||||
|
# 创建服务器 |
||||
|
server = await websockets.serve( |
||||
|
handle_websocket, |
||||
|
"0.0.0.0", |
||||
|
27001, |
||||
|
ssl=ssl_context, |
||||
|
ping_interval=30, # 添加ping间隔 |
||||
|
ping_timeout=30, # 添加ping超时 |
||||
|
close_timeout=30 # 添加关闭超时 |
||||
|
) |
||||
|
|
||||
|
if ssl_context: |
||||
|
print("WSS服务器已启动: wss://0.0.0.0:27001") |
||||
|
else: |
||||
|
print("WS服务器已启动: ws://0.0.0.0:27001") |
||||
|
|
||||
|
# 保持服务器运行 |
||||
|
await server.wait_closed() |
||||
|
|
||||
|
except Exception as e: |
||||
|
print(f"服务器启动失败: {e}") |
||||
|
import traceback |
||||
|
traceback.print_exc() |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
# 设置更详细的事件循环调试 |
||||
|
import logging |
||||
|
|
||||
|
logging.basicConfig(level=logging.INFO) |
||||
|
|
||||
|
try: |
||||
|
asyncio.run(main_api()) |
||||
|
except KeyboardInterrupt: |
||||
|
print("服务器被用户中断") |
||||
|
except Exception as e: |
||||
|
print(f"服务器运行错误: {e}") |
Loading…
Reference in new issue