From 823923c92761f82d710d3ef00e71facc84a988ef Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Mon, 14 Apr 2025 17:16:19 +0800
Subject: [PATCH] Single-record data upload
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py | 341 ++++++++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 248 insertions(+), 93 deletions(-)

diff --git a/main.py b/main.py
index 306b1bf..87f8823 100644
--- a/main.py
+++ b/main.py
@@ -2,21 +2,37 @@
 
 # 按 Shift+F10 执行或将其替换为您的代码。
 # 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
+import os
 import faiss
 import numpy as np
 from tqdm import tqdm
 from sentence_transformers import SentenceTransformer
 import requests
 import time
-from flask import Flask, jsonify
-from flask import request
+from flask import Flask, jsonify, Response, request
+from openai import OpenAI
+from flask_cors import CORS
 import pandas as pd
+import concurrent.futures
+import json
 
 
 app = Flask(__name__)
+CORS(app)
 app.config["JSON_AS_ASCII"] = False
 
-model = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')
+openai_api_key = "token-abc123"
+openai_api_base = "http://127.0.0.1:12011/v1"
+
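+# OpenAI-compatible client pointing at the local inference service on port 12011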
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+# model = models.data[0].id
+model = "1"
+model_encode = SentenceTransformer('/home/majiahui/project/models-llm/bge-large-zh-v1.5')
 propmt_connect = '''我是一名中医,你是一个中医的医生的助理,我的患者有一个症状,症状如下:
 {}
 根据这些症状,我通过查找资料,{}
@@ -26,7 +42,7 @@ propmt_connect_ziliao = '''在“{}”资料中,有如下相关内容:
 {}'''
 
 
-def dialog_line_parse(url, text):
+def dialog_line_parse(text):
     """
     将数据输入模型进行分析并输出结果
     :param url: 模型url
@@ -34,8 +50,9 @@ def dialog_line_parse(url, text):
     :return: 模型返回结果
     """
 
+    url_predict = "http://118.178.228.101:12004/predict"
     response = requests.post(
-        url,
+        url_predict,
         json=text,
         timeout=100000
     )
@@ -49,46 +66,167 @@ def dialog_line_parse(url, text):
         # )
         print("【{}】 Failed to get a proper response from remote "
               "server. Status Code: {}. Response: {}"
-              "".format(url, response.status_code, response.text))
+              "".format(url_predict, response.status_code, response.text))
         return {}
 
+    # ['choices'][0]['message']['content']
+    #
+    # text = text['messages'][0]['content']
+    # return_text = {
+    #     'code': 200,
+    #     'id': "1",
+    #     'object': 0,
+    #     'created': 0,
+    #     'model': 0,
+    #     'choices': [
+    #         {
+    #             'index': 0,
+    #             'message': {
+    #                 'role': 'assistant',
+    #                 'content': text
+    #             },
+    #             'logprobs': None,
+    #             'finish_reason': 'stop'
+    #         }
+    #     ],
+    #     'usage': 0,
+    #     'system_fingerprint': 0
+    # }
+    # return return_text
+
 
 def shengcehng_array(data):
-    embs = model.encode(data, normalize_embeddings=True)
+    embs = model_encode.encode(data, normalize_embeddings=True)
     return embs
 
-def Building_vector_database(type, name, df):
-    data_ndarray = np.empty((0, 1024))
-    for sen in df:
-        data_ndarray = np.concatenate((data_ndarray, shengcehng_array([sen[0]])))
-        print("data_ndarray.shape", data_ndarray.shape)
 
-    print("data_ndarray.shape", data_ndarray.shape)
-    np.save(f'data_np/{name}.npy', data_ndarray)
 
+def Building_vector_database(title, df):
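+    # Storage layout: data_np/{title}.npy holds the embedding matrix,
+    # data_np/{title}_index.json maps each record ID to its row and validity flag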
+    # Load the records that still need processing (valid and not yet vectorized)
+    to_process = df[(df["有效"] == True) & (df["已向量化"] == False)]
 
+    if len(to_process) == 0:
+        print("无新增数据需要向量化")
+        return
 
-def ulit_request_file(file, title):
-    file_name = file.filename
-    file_name_save = "data_file/{}.csv".format(title)
-    file.save(file_name_save)
+    # Generate embeddings for the 总结 column of the new records
+    new_vectors = shengcehng_array(to_process["总结"].tolist())
 
-    # try:
-    #     with open(file_name_save, encoding="gbk") as f:
-    #         content = f.read()
-    # except:
-    #     with open(file_name_save, encoding="utf-8") as f:
-    #         content = f.read()
-    # elif file_name.split(".")[-1] == "docx":
-    #     content = docx2txt.process(file_name_save)
+    # Load the existing vector store and ID index (if any)
+    vector_path = f"data_np/{title}.npy"
+    index_path = f"data_np/{title}_index.json"
 
-    # content_list = [i for i in content.split("\n")]
-    df = pd.read_csv(file_name_save, sep="\t", encoding="utf-8").values.tolist()
+    vectors = np.load(vector_path) if os.path.exists(vector_path) else np.empty((0, 1024))
+    index_data = {}
+    if os.path.exists(index_path):
+        with open(index_path, "r") as f:
+            index_data = json.load(f)
 
-    return df
+    # Append the new vectors and record each ID's row in the index
+    start_idx = len(vectors)
+    vectors = np.vstack([vectors, new_vectors])
+
+    for i, (_, row) in enumerate(to_process.iterrows()):
+        index_data[row["ID"]] = {
+            "row": start_idx + i,
+            "valid": True
+        }
+
+    # Persist the vectors and the index
+    np.save(vector_path, vectors)
+    with open(index_path, "w") as f:
+        json.dump(index_data, f)
 
+    # Mark the processed rows as vectorized and save the CSV
+    df.loc[to_process.index, "已向量化"] = True
+    df.to_csv(f"data_file_res/{title}.csv", sep="\t", index=False)
+
+
+def delete_data(title, data_id):
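+    # Soft delete: the record is only flagged as invalid; its vector stays in the
+    # .npy file and is filtered out at query time in main()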
+    # Flag the record as invalid in the CSV
+    csv_path = f"data_file_res/{title}.csv"
+    df = pd.read_csv(csv_path, sep="\t", dtype={"ID": str})
+    df.loc[df["ID"] == data_id, "有效"] = False
+    df.to_csv(csv_path, sep="\t", index=False)
+
+    # Flag the record as invalid in the JSON index
+    index_path = f"data_np/{title}_index.json"
+    if os.path.exists(index_path):
+        with open(index_path, "r+") as f:
+            index_data = json.load(f)
+            if data_id in index_data:
+                index_data[data_id]["valid"] = False
+                f.seek(0)
+                json.dump(index_data, f)
+                f.truncate()
+
+
+def check_file_exists(file_path):
+    """
+    检查文件是否存在
 
-def main(question, db_type, top):
+    Args:
+        file_path (str): path of the file to check
+
+    Returns:
+        bool: True if the file exists, otherwise False
+    """
+    return os.path.isfile(file_path)
+
+
+def ulit_request_file(new_id, sentence, title):
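+    # CSV schema: ID, 正文 (original text), 总结 (generated summary), 有效 (valid flag), 已向量化 (vectorized flag)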
+    file_name_res_save = f"data_file_res/{title}.csv"
+
+    # Create the CSV on first use, otherwise load the existing one
+    if os.path.exists(file_name_res_save):
+        df = pd.read_csv(file_name_res_save, sep="\t", dtype={"ID": str})
+        # Skip records whose 正文 is already stored
+        if sentence in df["正文"].values:
+            print("正文已存在,跳过处理")
+            return df
+    else:
+        df = pd.DataFrame(columns=["ID", "正文", "总结", "有效", "已向量化"])
+
+    # Append the new record using the ID supplied by the caller
+    new_row = {
+        "ID": str(new_id),
+        "正文": sentence,
+        "总结": None,
+        "有效": True,
+        "已向量化": False
+    }
+    df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
+
+    # Select the records that still need a summary
+    to_process = df[(df["总结"].isna()) & (df["有效"] == True)]
+
+    # Build one summarization request per record for the remote predict service
+    data_list = []
+    for _, row in to_process.iterrows():
+        data_list.append({
+            "model": "gpt-4-turbo",
+            "messages": [{
+                "role": "user",
+                "content": f"{row['正文']}\n以上这条中可能包含了一些病情或者症状,请帮我归纳这条中所对应的病情或者症状是哪些,总结出来,不需要很长,简单归纳即可,直接输出症状或者病情,可以包含一些形容词来辅助描述,不需要有辅助词汇"
+            }],
+            "top_p": 0.9,
+            "temperature": 0.6
+        })
+
+    # Send the requests concurrently
+    with concurrent.futures.ThreadPoolExecutor(max_workers=200) as executor:
+        results = list(executor.map(dialog_line_parse, data_list))
+
+    # Write the generated summaries back; skip records whose request failed
+    for idx, result in zip(to_process.index, results):
+        if not result:
+            continue
+        summary = result['choices'][0]['message']['content']
+        df.at[idx, "总结"] = summary
+
+    # Save the updated CSV
+    df.to_csv(file_name_res_save, sep="\t", index=False)
+    return df
+
+def main(question, title, top):
     db_dict = {
         "1": "yetianshi"
     }
@@ -114,30 +252,43 @@ def main(question, db_type, top):
     根据提问匹配上下文
     '''
     d = 1024
-    db_type_list = db_type.split(",")
+    db_type_list = title.split(",")
 
     paper_list_str = ""
-    for i in db_type_list:
+    for title_dan in db_type_list:
         embs = shengcehng_array([question])
         index = faiss.IndexFlatIP(d)  # buid the index
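+        # With normalize_embeddings=True, inner-product search is equivalent to cosine similarity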
-        data_np = np.load(f"data_np/{i}.npy")
-        # data_str = open(f"data_file/{i}.txt").read().split("\n")
-        data_str = pd.read_csv(f"data_file/{i}.csv", sep="\t", encoding="utf-8").values.tolist()
-        index.add(data_np)
+
+        # Load the vector store and ID index for this knowledge base
+        vector_path = f"data_np/{title_dan}.npy"
+        index_path = f"data_np/{title_dan}_index.json"
+
+        if not os.path.exists(vector_path) or not os.path.exists(index_path):
+            # No vector store exists for this title yet, so return a plain message
+            # instead of an ndarray (the caller wraps the return value in a Flask Response)
+            return "该知识库尚未建立,请先上传数据"
+
+        vectors = np.load(vector_path)
+        with open(index_path, "r") as f:
+            index_data = json.load(f)
+
+        data_str = pd.read_csv(f"data_file_res/{title_dan}.csv", sep="\t", encoding="utf-8").values.tolist()
+        index.add(vectors)
         D, I = index.search(embs, int(top))
         print(I)
 
         reference_list = []
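+        # data_str rows follow the CSV columns [ID, 正文, 总结, 有效, 已向量化]; keep only rows whose 有效 flag is True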
         for i,j in zip(I[0], D[0]):
-            reference_list.append([data_str[i], j])
+            if data_str[i][3] == True:
+                reference_list.append([data_str[i], j])
 
         for i,j in enumerate(reference_list):
-            paper_list_str += "第{}篇\n{},此篇文章的转发数为{},评论数为{},点赞数为{}\n,此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][0], j[0][1], j[0][2], j[0][3], j[1])
+            paper_list_str += "第{}篇\n{},此篇文章跟问题的相关度为{}%\n".format(str(i+1), j[0][1], j[1])
+
 
     '''
     构造prompt
     '''
     print("paper_list_str", paper_list_str)
     propmt_connect_ziliao_input = []
     for i in db_type_list:
         propmt_connect_ziliao_input.append(propmt_connect_ziliao.format(i, paper_list_str))
@@ -147,61 +298,70 @@ def main(question, db_type, top):
     '''
     生成回答
     '''
-    url_predict = "http://192.168.31.74:26000/predict"
-    url_search = "http://192.168.31.74:26000/search"
-
-    # data = {
-    #     "content": propmt_connect_input
-    # }
-    data = {
-        "content": propmt_connect_input,
-        "model": "qwq-32",
-        "top_p": 0.9,
-        "temperature": 0.6
-    }
-
-    res = dialog_line_parse(url_predict, data)
-    id_ = res["texts"]["id"]
-
-    data = {
-        "id": id_
-    }
-
-    while True:
-        res = dialog_line_parse(url_search, data)
-        if res["code"] == 200:
-            break
-        else:
-            time.sleep(1)
-    spilt_str = "</think>"
-    think, response = str(res["text"]).split(spilt_str)
-    return think, response
-
-
-@app.route("/upload_file", methods=["POST"])
-def upload_file():
-    print(request.remote_addr)
-    file = request.files.get('file')
-    title = request.form.get("title")
-    df = ulit_request_file(file, title)
-    Building_vector_database("1", title, df)
-    return_json = {
-        "code": 200,
-        "info": "上传完成"
-    }
-    return jsonify(return_json)  # 返回结果
+    return model_generate_stream(propmt_connect_input)
+
+
+def model_generate_stream(prompt):
+    messages = [
+        {"role": "user", "content": prompt}
+    ]
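+    # Stream the chat completion; reasoning tokens are only logged to stdout,
+    # while content tokens are yielded to the HTTP response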
+
+    stream = client.chat.completions.create(model=model,
+                                            messages=messages,
+                                            stream=True)
+    printed_reasoning_content = False
+    printed_content = False
+
+    for chunk in stream:
+        delta = chunk.choices[0].delta
+        # A delta may carry reasoning_content, content, or neither; read both safely
+        reasoning_content = getattr(delta, "reasoning_content", None)
+        content = getattr(delta, "content", None)
+
+        if reasoning_content is not None:
+            if not printed_reasoning_content:
+                printed_reasoning_content = True
+                print("reasoning_content:", end="", flush=True)
+            print(reasoning_content, end="", flush=True)
+        elif content is not None:
+            if not printed_content:
+                printed_content = True
+                print("\ncontent:", end="", flush=True)
+            # Echo the chunk to stdout for debugging, without adding extra newlines
+            print(content, end="", flush=True)
+            yield content
 
 
 @app.route("/upload_file_check", methods=["POST"])
 def upload_file_check():
     print(request.remote_addr)
-    file = request.files.get('file')
+    sentence = request.form.get('sentence')
     title = request.form.get("title")
-    df = ulit_request_file(file, title)
-    Building_vector_database("1", title, df)
+    new_id = request.form.get("id")
+    state = request.form.get("state")
+    '''
+    state:
+        "1": add a single record and update the vector store
+        "2": mark the record with the given id as deleted
+    '''
+    state_res = ""
+    if state == "1":
+        df = ulit_request_file(new_id, sentence, title)
+        Building_vector_database(title, df)
+        state_res = "上传完成"
+    elif state == "2":
+        delete_data(title, new_id)
+        state_res = "删除完成"
     return_json = {
         "code": 200,
-        "info": "上传完成"
+        "info": state_res
     }
     return jsonify(return_json)  # 返回结果
 
@@ -210,15 +370,10 @@ def upload_file_check():
 def search():
     print(request.remote_addr)
     texts = request.json["texts"]
-    text_type = request.json["text_type"]
+    title = request.json["title"]
     top = request.json["top"]
-    think, response = main(texts, text_type, top)
-    return_json = {
-        "code": 200,
-        "think": think,
-        "response": response
-    }
-    return jsonify(return_json)  # 返回结果
+    response = main(texts, title, top)
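+    # main() returns a generator of content chunks (or a plain error string); Flask streams it as text/plain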
+    return Response(response, mimetype='text/plain; charset=utf-8')  # 返回结果
 
 
 if __name__ == "__main__":