
1 changed files with 241 additions and 0 deletions
@ -0,0 +1,241 @@ |
|||
# 这是一个示例 Python 脚本。 |
|||
|
|||
# 按 Shift+F10 执行或将其替换为您的代码。 |
|||
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 |
|||
import os |
|||
import faiss |
|||
import numpy as np |
|||
from tqdm import tqdm |
|||
from sentence_transformers import SentenceTransformer |
|||
import requests |
|||
import time |
|||
from flask import Flask, jsonify, Response, request |
|||
from openai import OpenAI |
|||
from flask_cors import CORS |
|||
import pandas as pd |
|||
import concurrent.futures |
|||
import json |
|||
from threading import Thread |
|||
import redis |
|||
import asyncio |
|||
import websockets |
|||
|
|||
app = Flask(__name__) |
|||
CORS(app) |
|||
app.config["JSON_AS_ASCII"] = False |
|||
|
|||
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=1, password="zhicheng123*") |
|||
redis_ = redis.Redis(connection_pool=pool, decode_responses=True) |
|||
|
|||
db_key_query = 'query' |
|||
db_key_querying = 'querying' |
|||
batch_size = 32 |
|||
|
|||
openai_api_key = "token-abc123" |
|||
openai_api_base = "http://127.0.0.1:12011/v1" |
|||
|
|||
client = OpenAI( |
|||
api_key=openai_api_key, |
|||
base_url=openai_api_base, |
|||
) |
|||
|
|||
models = client.models.list() |
|||
model = models.data[0].id |
|||
# model = "1" |
|||
model_encode = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5') |
|||
propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下: |
|||
{} |
|||
根据这些症状,我通过查找资料,{} |
|||
请根据上面的这些资料和方子,并根据每篇文章的转发数确定文章的重要程度,转发数越高的文章,最终答案的参考度越高,反之越低。根据患者的症状和上面的文章的资料的重要程度以及文章和症状的匹配程度,帮我开出正确的药方和治疗方案''' |
|||
|
|||
propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容: |
|||
{}''' |
|||
|
|||
|
|||
def dialog_line_parse(text): |
|||
""" |
|||
将数据输入模型进行分析并输出结果 |
|||
:param url: 模型url |
|||
:param text: 进入模型的数据 |
|||
:return: 模型返回结果 |
|||
""" |
|||
|
|||
url_predict = "http://118.178.228.101:12004/predict" |
|||
response = requests.post( |
|||
url_predict, |
|||
json=text, |
|||
timeout=100000 |
|||
) |
|||
if response.status_code == 200: |
|||
return response.json() |
|||
else: |
|||
# logger.error( |
|||
# "【{}】 Failed to get a proper response from remote " |
|||
# "server. Status Code: {}. Response: {}" |
|||
# "".format(url, response.status_code, response.text) |
|||
# ) |
|||
print("【{}】 Failed to get a proper response from remote " |
|||
"server. Status Code: {}. Response: {}" |
|||
"".format(url_predict, response.status_code, response.text)) |
|||
return {} |
|||
|
|||
# ['choices'][0]['message']['content'] |
|||
# |
|||
# text = text['messages'][0]['content'] |
|||
# return_text = { |
|||
# 'code': 200, |
|||
# 'id': "1", |
|||
# 'object': 0, |
|||
# 'created': 0, |
|||
# 'model': 0, |
|||
# 'choices': [ |
|||
# { |
|||
# 'index': 0, |
|||
# 'message': { |
|||
# 'role': 'assistant', |
|||
# 'content': text |
|||
# }, |
|||
# 'logprobs': None, |
|||
# 'finish_reason': 'stop' |
|||
# } |
|||
# ], |
|||
# 'usage': 0, |
|||
# 'system_fingerprint': 0 |
|||
# } |
|||
# return return_text |
|||
|
|||
|
|||
def shengcehng_array(data): |
|||
embs = model_encode.encode(data, normalize_embeddings=True) |
|||
return embs |
|||
|
|||
|
|||
def main(question, title, top): |
|||
db_dict = { |
|||
"1": "yetianshi" |
|||
} |
|||
''' |
|||
定义文件路径 |
|||
''' |
|||
|
|||
''' |
|||
加载文件 |
|||
''' |
|||
|
|||
''' |
|||
文本分割 |
|||
''' |
|||
|
|||
''' |
|||
构建向量数据库 |
|||
1. 正常匹配 |
|||
2. 把文本使用大模型生成一个问题之后再进行匹配 |
|||
''' |
|||
|
|||
''' |
|||
根据提问匹配上下文 |
|||
''' |
|||
d = 1024 |
|||
db_type_list = title.split(",") |
|||
|
|||
paper_list_str = "" |
|||
for title_dan in db_type_list: |
|||
embs = shengcehng_array([question]) |
|||
index = faiss.IndexFlatIP(d) # buid the index |
|||
|
|||
# 查找向量 |
|||
vector_path = f"data_np/{title_dan}.npy" |
|||
vectors = np.load(vector_path) |
|||
|
|||
data_str = pd.read_csv(f"data_file/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist() |
|||
index.add(vectors) |
|||
D, I = index.search(embs, int(top)) |
|||
print(I) |
|||
|
|||
reference_list = [] |
|||
for i, j in zip(I[0], D[0]): |
|||
reference_list.append([data_str[i], j]) |
|||
|
|||
for i, j in enumerate(reference_list): |
|||
paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i + 1), j[0][1], j[1]) |
|||
|
|||
|
|||
''' |
|||
构造prompt |
|||
''' |
|||
print("paper_list_str", paper_list_str) |
|||
propmt_connect_ziliao_input = [] |
|||
for i in db_type_list: |
|||
propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(i, paper_list_str)) |
|||
|
|||
propmt_connect_ziliao_input_str = ",".join(propmt_connect_ziliao_input) |
|||
propmt_connect_input = propmt_connect.format(question, propmt_connect_ziliao_input_str) |
|||
''' |
|||
生成回答 |
|||
''' |
|||
return model_generate_stream(propmt_connect_input) |
|||
|
|||
|
|||
def model_generate_stream(prompt): |
|||
messages = [ |
|||
{"role": "user", "content": prompt} |
|||
] |
|||
|
|||
stream = client.chat.completions.create(model=model, |
|||
messages=messages, |
|||
stream=True) |
|||
printed_reasoning_content = False |
|||
printed_content = False |
|||
|
|||
for chunk in stream: |
|||
reasoning_content = None |
|||
content = None |
|||
# Check the content is reasoning_content or content |
|||
if hasattr(chunk.choices[0].delta, "reasoning_content"): |
|||
reasoning_content = chunk.choices[0].delta.reasoning_content |
|||
elif hasattr(chunk.choices[0].delta, "content"): |
|||
content = chunk.choices[0].delta.content |
|||
|
|||
if reasoning_content is not None: |
|||
if not printed_reasoning_content: |
|||
printed_reasoning_content = True |
|||
print("reasoning_content:", end="", flush=True) |
|||
print(reasoning_content, end="", flush=True) |
|||
elif content is not None: |
|||
if not printed_content: |
|||
printed_content = True |
|||
print("\ncontent:", end="", flush=True) |
|||
# Extract and print the content |
|||
# print(content, end="", flush=True) |
|||
print(content) |
|||
yield content |
|||
|
|||
|
|||
async def handle_websocket(websocket): |
|||
print("客户端已连接") |
|||
try: |
|||
while True: |
|||
message = await websocket.recv() |
|||
data = json.loads(message) |
|||
texts = data.get("texts") |
|||
title = data.get("title") |
|||
top = data.get("top") |
|||
print("收到消息:", message) |
|||
|
|||
response = main(texts, title, top) |
|||
# response = message + "111" |
|||
for char in response: |
|||
await websocket.send(char) |
|||
# await asyncio.sleep(0.3) |
|||
await websocket.send("[DONE]") |
|||
except websockets.exceptions.ConnectionClosed: |
|||
print("客户端断开连接") |
|||
|
|||
|
|||
async def main_api(): |
|||
async with websockets.serve(handle_websocket, "0.0.0.0", 27001): |
|||
print("WebSocket 服务器已启动,监听端口 27001") |
|||
await asyncio.Future() # 永久运行 |
|||
|
|||
if __name__ == "__main__": |
|||
asyncio.run(main_api()) # 正确启动事件循环 |
Loading…
Reference in new issue