参考文献生成项目,使用faiss实现
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

40 lines
1.3 KiB

import numpy as np
import faiss
import json
import math
d = 768  # dimension of every embedding vector

# JSON file whose values name the per-discipline ``.npy`` files.
DISCIPLINE_TYPES_PATH = "data/discipline_types.json"
# Directories holding one ``.npy`` matrix per discipline; the matrices are
# concatenated in exactly this order.
SOURCE_DIRS = ("data/prompt_qikan", "data/prompt_master", "data/prompt_doctor")
# Directory that receives one trained ``.ivf`` index per discipline.
INDEX_DIR = "data/prompt_qikan_master_doctor_ivf"


def ivf_cluster_count(num_vectors):
    """Return the number of IVF clusters (``nlist``) for *num_vectors* vectors.

    The count is the square root of the number of vectors, rounded down.
    """
    return math.floor(num_vectors ** 0.5)


def load_discipline_vectors(discipline_code):
    """Load and stack the embedding matrices of one discipline.

    Reads ``<source_dir>/<discipline_code>.npy`` from every directory in
    ``SOURCE_DIRS``, prints the shape of each matrix, and concatenates them
    along the first axis in that order.

    Returns:
        The stacked matrix as a C-contiguous float32 array.
    """
    parts = [
        np.load(f"{source_dir}/{discipline_code}.npy")
        for source_dir in SOURCE_DIRS
    ]
    for label, part in zip(("xb", "xb_1", "xb_2"), parts):
        print(f"{label}.shape", part.shape)
    stacked = np.concatenate(parts)
    # faiss works on C-contiguous float32 matrices; this is a no-op when the
    # stored arrays already have that layout.
    return np.ascontiguousarray(stacked, dtype=np.float32)


def train_ivf_index(vectors):
    """Build an ``IndexIVFFlat`` of dimension ``d`` and train it on *vectors*.

    The number of clusters is derived from the number of rows in *vectors*
    and printed before training.

    Returns:
        The trained index. No vectors are added to it.
    """
    nlist = ivf_cluster_count(vectors.shape[0])
    print(nlist)
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)
    assert not index.is_trained
    # IndexIVFFlat has to be trained: this step learns the clustering.
    index.train(vectors)
    assert index.is_trained
    return index


def main():
    """Train and save one IVF index per discipline, then print the total
    number of vectors seen across all disciplines."""
    # Explicit encoding so the label file decodes the same way regardless of
    # the platform's default locale.
    with open(DISCIPLINE_TYPES_PATH, encoding="utf-8") as f:
        discipline_types = json.load(f)

    total_vectors = 0
    for discipline_code in discipline_types.values():
        vectors = load_discipline_vectors(discipline_code)
        print(vectors.shape)
        total_vectors += vectors.shape[0]
        index = train_ivf_index(vectors)
        # NOTE(review): the index is written trained but empty (nothing is
        # added to it here) -- presumably vectors are added by a later step;
        # confirm against the consumer of these files.
        faiss.write_index(index, f"{INDEX_DIR}/{discipline_code}.ivf")
    print(total_vectors)


if __name__ == "__main__":
    main()