参考文献生成项目，使用 FAISS 实现
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

134 lines
5.3 KiB

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json
import numpy as np
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import re
# Sentence-embedding model used by shengcehng_array(); loaded once at import time.
model = SentenceTransformer('Dmeta-embedding-zh')
print(1)


def _read_json_file(path):
    """Parse and return the UTF-8 encoded JSON document stored at *path*."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Topic-name -> discipline mapping; the __main__ section tests membership with
# `in` and indexes it by topic name.
lable_discipline_types = _read_json_file("data/discipline_types.json")
def erjimul_ulit():
    """Placeholder with no implementation yet; does nothing and returns None.

    NOTE(review): intended purpose is not visible in this file — TODO confirm or remove.
    """
def shengcehng_array(data):
    """Embed *data* with the module-level SentenceTransformer ``model``.

    :param data: a string or a list of strings to embed.
    :return: whatever ``model.encode`` returns for *data* with
        ``normalize_embeddings=True`` (L2-normalized vectors, so an inner
        product between two of them equals their cosine similarity).
    """
    return model.encode(data, normalize_embeddings=True)
# Whole basic CJK Unified Ideographs block (U+4E00-U+9FFF). An upper bound of
# U+9FA5 covers only the original Unicode 1.1 repertoire and misses ideographs
# added later at U+9FA6-U+9FFF. Compiled once instead of on every call.
_CHINESE_CHAR_PATTERN = re.compile(r'[\u4e00-\u9fff]')


def is_contain_chinese(word):
    """Return whether *word* contains at least one Chinese ideograph.

    :param word: text to inspect; non-string values are converted with ``str()``
        first instead of raising ``TypeError``.
    :return: True if any character lies in U+4E00-U+9FFF, otherwise False.
    """
    if not isinstance(word, str):
        word = str(word)
    return _CHINESE_CHAR_PATTERN.search(word) is not None
if __name__ == '__main__':
# data = []
with open("data/data_0423_qikan.json", encoding="utf-8") as f:
# for i in f.readlines():
# a = json.loads(i)
# data.append(a)
data = json.loads(f.read())
print(len(data))
a = 0
a_ = 0
data_info = {} # 作者 论文名称 论文类别 论文来源 论文年份 摘要
data_prompt = {}
data_info_en = {} # 作者 论文名称 论文类别 论文来源 论文年份 摘要
data_prompt_en = {}
for data_dan in data:
if str(data_dan["special_topic"]) == "nan" or \
str(data_dan["author"]) == "nan" or \
str(data_dan["title"]) == "nan" or \
str(data_dan["qikan_name"]) == "nan" or \
str(data_dan["year"]) == "nan" or \
str(data_dan["abstract"]) == "nan":
a_ += 1
continue
leibie_list = data_dan["special_topic"].split(";")
for leibie in leibie_list:
if leibie in lable_discipline_types:
zh_bool = is_contain_chinese(data_dan["title"])
if zh_bool == True:
if lable_discipline_types[leibie] not in data_prompt:
dan_data_prompt = "标题:“{}”,摘要:“{}".format(data_dan["title"], data_dan["abstract"])
data_prompt[lable_discipline_types[leibie]] = [dan_data_prompt]
data_info[lable_discipline_types[leibie]] = [
[data_dan["author"], data_dan["title"], data_dan["special_topic"], data_dan["qikan_name"],
data_dan["year"], data_dan["abstract"], "期刊"]]
else:
dan_data_prompt = "标题:“{}”,摘要:“{}".format(data_dan["title"], data_dan["abstract"])
data_prompt[lable_discipline_types[leibie]].append(dan_data_prompt)
data_info[lable_discipline_types[leibie]].append(
[data_dan["author"], data_dan["title"], data_dan["special_topic"], data_dan["qikan_name"],
data_dan["year"], data_dan["abstract"], "期刊"])
else:
if lable_discipline_types[leibie] not in data_prompt_en:
dan_data_prompt = "标题:“{}”,摘要:“{}".format(data_dan["title"], data_dan["abstract"])
data_prompt_en[lable_discipline_types[leibie]] = [dan_data_prompt]
data_info_en[lable_discipline_types[leibie]] = [
[data_dan["author"], data_dan["title"], data_dan["special_topic"], data_dan["qikan_name"],
data_dan["year"], data_dan["abstract"], "期刊"]]
else:
dan_data_prompt = "标题:“{}”,摘要:“{}".format(data_dan["title"], data_dan["abstract"])
data_prompt_en[lable_discipline_types[leibie]].append(dan_data_prompt)
data_info_en[lable_discipline_types[leibie]].append(
[data_dan["author"], data_dan["title"], data_dan["special_topic"], data_dan["qikan_name"],
data_dan["year"], data_dan["abstract"], "期刊"])
a += 1
print(2)
strat = 0
end = 10000
print(len(data_prompt))
for leibie in tqdm(data_prompt):
data_ndarray = np.empty((0, 768))
print("len(data_prompt[leibie])", len(data_prompt[leibie]))
while True:
if end >= len(data_prompt[leibie]):
break
linshi_data = data_prompt[leibie][strat:end]
data_ndarray = np.concatenate((data_ndarray, shengcehng_array(linshi_data)))
print("data_ndarray.shape", data_ndarray.shape)
strat = end
end += 10000
linshi_data = data_prompt[leibie][strat:len(data_prompt[leibie])]
print("len(linshi_data)", len(linshi_data))
data_ndarray = np.concatenate((data_ndarray, shengcehng_array(linshi_data)))
print("data_ndarray.shape", data_ndarray.shape)
np.save(f'data/prompt_qikan/{leibie}.npy', data_ndarray)
strat = 0
end = 10000
for leibie in data_info:
print(len(data_info[leibie]))
with open(f"data/data_info_qikan/{leibie}.json", "w", encoding="utf-8") as f:
f.write(json.dumps(data_info[leibie], ensure_ascii=False, indent=2))
for i in data_prompt_en:
print(i)
print(len(data_prompt_en[i]))
print(len(data))
print(a_)