You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.6 KiB
85 lines
2.6 KiB
from openai import OpenAI
|
|
|
|
# Connection settings for the OpenAI-compatible endpoint this script talks to.
openai_api_key = "token-abc123"
openai_api_base = "http://127.0.0.1:12011/v1"

# Shared client used by every request issued from this module.
client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)

# Ask the server which models it exposes and use the first one listed.
# NOTE(review): this runs at import time and raises IndexError when the
# returned model list is empty.
models = client.models.list()
first_model = models.data[0]
model = first_model.id
|
|
|
|
|
|
def model_generate_stream(prompt):
    """Stream the model's reply to ``prompt``, yielding the answer text in pieces.

    ``prompt`` is sent as a single user message through the module-level
    ``client`` using the module-level ``model`` with ``stream=True``.
    Reasoning deltas (``delta.reasoning_content``) are only echoed to stdout;
    answer deltas (``delta.content``) are printed and yielded to the caller.

    Args:
        prompt: Text used as the content of the user message.

    Yields:
        Every non-None ``content`` delta, in the order it arrives.
    """
    messages = [{"role": "user", "content": prompt}]
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
    )
    printed_reasoning_content = False
    printed_content = False

    for chunk in stream:
        # Guard against a chunk whose ``choices`` list is empty; indexing
        # ``[0]`` on it would raise IndexError and abort the whole stream.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta

        # ``reasoning_content`` is a non-standard field: it may be absent or
        # present with a None value. Reading it with getattr and falling back
        # to ``content`` keeps the answer text from being dropped in the
        # "present but None" case.
        reasoning_content = getattr(delta, "reasoning_content", None)
        if reasoning_content is not None:
            if not printed_reasoning_content:
                printed_reasoning_content = True
                print("reasoning_content:", end="", flush=True)
            print(reasoning_content, end="", flush=True)
            continue

        content = getattr(delta, "content", None)
        if content is None:
            continue

        if not printed_content:
            printed_content = True
            print("\ncontent:", end="", flush=True)
        # Echo the piece locally, then hand it to the consumer.
        print(content)
        yield content
|
|
# if __name__ == '__main__':
|
|
# for i in model_generate_stream("你好"):
|
|
# print(i)
|
|
|
|
|
|
import asyncio
|
|
import websockets
|
|
import json
|
|
|
|
|
|
async def handle_websocket(websocket):
    """Serve one WebSocket client, answering each JSON request with a streamed reply.

    Every received message is parsed as JSON and its ``"texts"`` value is
    passed to ``model_generate_stream``. Each piece the generator yields is
    sent as its own message, followed by the literal ``"[DONE]"`` marker.
    The loop ends when the connection is closed.

    Args:
        websocket: Connection object handed over by ``websockets.serve``.

    NOTE(review): a message that is not valid JSON (or not a JSON object)
    raises out of this handler, since only ``ConnectionClosed`` is caught.
    """
    print("客户端已连接")
    loop = asyncio.get_running_loop()
    exhausted = object()  # sentinel returned by next() when the stream ends
    try:
        while True:
            message = await websocket.recv()
            print("收到消息:", message)

            data = json.loads(message)
            texts = data.get("texts")

            response = model_generate_stream(texts)
            while True:
                # The generator performs blocking network reads. Pulling each
                # piece in a worker thread keeps the event loop free to serve
                # other connections. The sentinel default avoids raising
                # StopIteration through the executor future.
                piece = await loop.run_in_executor(
                    None, next, response, exhausted
                )
                if piece is exhausted:
                    break
                await websocket.send(piece)
            await websocket.send("[DONE]")
    except websockets.exceptions.ConnectionClosed:
        print("客户端断开连接")
|
|
|
|
async def main(host="0.0.0.0", port=5500):
    """Start the WebSocket server and keep it running until the task is cancelled.

    Args:
        host: Interface address to bind to. Defaults to ``"0.0.0.0"``.
        port: TCP port to listen on. Defaults to ``5500``.
    """
    async with websockets.serve(handle_websocket, host, port):
        print(f"WebSocket 服务器已启动,监听端口 {port}")
        # A future that is never resolved: keeps the server context open
        # until the surrounding task is cancelled.
        await asyncio.Future()
|
|
|
|
if __name__ == "__main__":
    # Script entry point: asyncio.run creates the event loop, drives main()
    # and closes the loop when it finishes.
    server_coroutine = main()
    asyncio.run(server_coroutine)
|