You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.3 KiB
40 lines
1.3 KiB
# Train one FAISS IVF-Flat index per discipline category.
#
# data/discipline_types.json maps each category key (a category label) to a
# file stem. For every stem, the script:
#   - stacks the embedding matrices stored as
#     data/prompt_qikan/<stem>.npy, data/prompt_master/<stem>.npy and
#     data/prompt_doctor/<stem>.npy,
#   - trains an IndexIVFFlat on the combined rows, and
#   - writes it to data/prompt_qikan_master_doctor_ivf/<stem>.ivf.
#
# NOTE(review): only train() is called -- no vectors are add()-ed -- so each
# written index holds the learned coarse clustering but zero vectors.
# Presumably vectors are added by a later step; confirm with the consumer.
import json
import math
import os

import faiss
import numpy as np

d = 768  # embedding dimension every .npy matrix must have

DISCIPLINE_TYPES_PATH = "data/discipline_types.json"
# Source directories, in the order their matrices are stacked.
SOURCE_DIRS = ("data/prompt_qikan", "data/prompt_master", "data/prompt_doctor")
OUTPUT_DIR = "data/prompt_qikan_master_doctor_ivf"


def _load_category_vectors(stem):
    """Load the three embedding matrices for *stem* and stack them row-wise.

    Each part's shape is printed before concatenation.
    """
    parts = []
    for label, directory in zip(("xb", "xb_1", "xb_2"), SOURCE_DIRS):
        part = np.load(f"{directory}/{stem}.npy")
        print(f"{label}.shape", part.shape)
        parts.append(part)
    return np.concatenate(parts)


def _train_and_save_ivf(xb, dim, out_path, category):
    """Train an IVF-Flat index on the rows of *xb* and write it to *out_path*.

    The number of inverted lists (clusters) is floor(sqrt(n)) for n rows.
    NOTE(review): faiss may warn when there are fewer than ~39 training
    points per centroid, which this heuristic implies for small categories.

    Raises:
        ValueError: if *xb* is not an (n, dim) matrix or has no rows.
        RuntimeError: if faiss reports the index as untrained after train().
    """
    if xb.ndim != 2 or xb.shape[1] != dim:
        raise ValueError(
            f"{category!r}: expected an (n, {dim}) matrix, got shape {xb.shape}"
        )
    n = xb.shape[0]
    if n == 0:
        # floor(sqrt(0)) == 0 inverted lists, which cannot be trained.
        raise ValueError(f"{category!r}: no training vectors found")

    nlist = math.floor(n ** 0.5)  # number of clusters (inverted lists)
    print(nlist)

    # Build, train and write within one scope so the quantizer object is
    # guaranteed to stay alive while the IVF index that refers to it is used.
    quantizer = faiss.IndexFlatL2(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist)
    # IndexIVFFlat must be trained before use; this learns the clustering.
    index.train(xb)
    if not index.is_trained:
        raise RuntimeError(f"{category!r}: faiss index failed to train")
    faiss.write_index(index, out_path)


# JSON text is UTF-8 by definition; don't depend on the platform default.
with open(DISCIPLINE_TYPES_PATH, encoding="utf-8") as f:
    lable_discipline_types = json.load(f)  # category key -> file stem

# faiss.write_index does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

a = 0  # running total of training vectors over all categories
for leibie_zh, stem in lable_discipline_types.items():
    xb = _load_category_vectors(stem)
    print(xb.shape)
    a += xb.shape[0]
    _train_and_save_ivf(xb, d, f"{OUTPUT_DIR}/{stem}.ivf", leibie_zh)

print(a)
|