diff --git a/generate_references_api.py b/generate_references_api_1.py
similarity index 74%
rename from generate_references_api.py
rename to generate_references_api_1.py
index 8b2b9ef..451c874 100644
--- a/generate_references_api.py
+++ b/generate_references_api_1.py
@@ -70,7 +70,23 @@ def dialog_line_parse(url, text):
         print(text)
         return []
 
-def ulit_recall_paper(reference_list):
+
+def panduan_paper_lable(paper_lable_text):
+    """Map a paper category label to its GB/T 7714 document type code.
+
+    Master's and doctoral theses are both dissertations ("D"); journal
+    articles are "J". Unknown labels fall back to "J", which is what the
+    formatter emitted for every entry before categories were introduced.
+    """
+    paper_lable = {
+        "硕士": "D",
+        "期刊": "J",
+        "博士": "D",
+    }
+    return paper_lable.get(paper_lable_text, "J")
+
+
+def ulit_recall_paper(reference_list, nums):
     '''
     对返回的十篇文章路径读取并解析
     :param recall_data_list_path:
@@ -89,13 +105,16 @@ def ulit_recall_paper(reference_list):
     # return data
 
     # recall_data_list
-    # 作者 论文名称 论文类别 论文来源 论文年份 摘要
+    # Record fields: author, title, category, source, year, abstract, paper type (硕士/期刊/博士)
     # "[1]赵璐.基于旅游资源开发下的新农村景观营建研究[D].西安建筑科技大学,2014."
+
     data = []
     for data_one in reference_list:
+        print("data_one", data_one)
+        print("data_one[0]", data_one[0])
         paper = ".".join([
-            ",".join([i for i in data_one[0].split(";") if i != ""]),
-            data_one[1] + "[J]",
+            ",".join([str(i).replace("\n", "").replace("\r", "") for i in data_one[0].split(";") if i != ""]),
+            data_one[1] + f"[{panduan_paper_lable(data_one[6])}]",
             ",".join([
                 data_one[3], str(data_one[4]) + "."
             ])
@@ -103,6 +122,15 @@ def ulit_recall_paper(reference_list):
 
         data.append(paper)
 
+    # print(data)
+    # De-duplicate while preserving the order of reference_list, so that
+    # truncating to `nums` keeps the leading entries; set() would return
+    # them in arbitrary order.
+    data = list(dict.fromkeys(data))
+    print(len(data))
+    print(data[:1])  # slice: does not raise when data is empty
+    print(nums)
+    data = data[:int(nums)]
     return data
 
 
@@ -128,15 +156,35 @@ def main(title, abstract, nums):
 
     # zidonghua = np.load('data/prompt/{subject_pinyin}.npy')
     data_subject = np.load(f"data/prompt_qikan/{subject_pinyin}.npy")
+    data_subject_1 = np.load(f"data/prompt_master/{subject_pinyin}.npy")
+    data_subject_2 = np.load(f"data/prompt_doctor/{subject_pinyin}.npy")
+    print("xb.shape", data_subject.shape)
+    print("xb_1.shape", data_subject_1.shape)
+    print("xb_2.shape", data_subject_2.shape)
+    data_subject = np.concatenate((data_subject, data_subject_1, data_subject_2))
+    print("data_subject.shape", data_subject.shape)
 
-    index = faiss.read_index(f'data/prompt_qikan_ivf/{subject_pinyin}.ivf')
+    index = faiss.read_index(f'data/prompt_qikan_master_doctor_ivf/{subject_pinyin}.ivf')
 
     with open(f"data/data_info_qikan/{subject_pinyin}.json") as f:
         data_info = json.loads(f.read())
 
+    with open(f"data/data_info_master/{subject_pinyin}.json") as f:
+        data_info_1 = json.loads(f.read())
+
+    with open(f"data/data_info_doctor/{subject_pinyin}.json") as f:
+        data_info_2 = json.loads(f.read())
+
+    print(len(data_info))
+    print(len(data_info_1))
+    print(len(data_info_2))
+    data_info = data_info + data_info_1 + data_info_2
+    print(len(data_info))
+    print(data_info[0])
     index.add(data_subject)
     # index.nprobe = 2 # default nprobe is 1, try a few more
-    k = nums
+    # k = nums
+    k = 50
     prompt = "标题:“{}”,摘要:“{}”".format(title, abstract)
     embs = model.encode([prompt], normalize_embeddings=True)
 
@@ -149,7 +197,7 @@ def main(title, abstract, nums):
         reference_list.append(data_info[i])
         abstract_list.append(data_info[i][5])
 
-    return "200", ulit_recall_paper(reference_list), abstract_list
+    return "200", ulit_recall_paper(reference_list, nums), abstract_list
 
 
 @app.route("/", methods=["POST"])
@@ -171,10 +219,7 @@ def handle_query():
     reference = [f"[{str(i+1)}]" + reference_list[i] for i in range(len(reference_list))]
     if status_code == "200":
         return_text = {
-            "resilt": {
-                "reference": reference,
-                "abstract": abstract_list
-            },
+            "resilt": reference,
             "probabilities": None,
             "status_code": 200
         }