"""Train one FAISS IVF-Flat index per discipline category and save it to disk.

Reads the mapping in ``data/discipline_types.json``, loads the vector matrix
``data/prompt_qikan/<value>.npy`` for every entry, trains an ``IndexIVFFlat``
whose cluster count is ``floor(sqrt(number_of_vectors))`` and writes the
trained index to ``data/prompt_qikan_ivf/<value>.ivf``.
"""
import json
import math

import faiss
import numpy as np

d = 768  # vector dimension passed to the FAISS index constructors

# Explicit UTF-8: the mapping holds non-ASCII labels, so decoding must not
# depend on the platform's default locale encoding.
with open("data/discipline_types.json", encoding="utf-8") as f:
    lable_discipline_types = json.load(f)

a = 0  # running total of vectors seen across all categories
for leibie_zh, file_stem in lable_discipline_types.items():
    # FAISS's Python interface works on C-contiguous float32 matrices; the
    # conversion is a no-op when the stored array already has that layout.
    xb = np.ascontiguousarray(
        np.load(f'data/prompt_qikan/{file_stem}.npy'), dtype=np.float32
    )
    print(xb.shape)

    # Number of clusters: square root of the number of vectors, rounded down.
    nlist = math.floor(xb.shape[0] ** 0.5)
    a += xb.shape[0]
    print(nlist)

    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist)

    # IndexIVFFlat has to be trained before use; training learns the clusters.
    index.train(xb)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not index.is_trained:
        raise RuntimeError(
            f"FAISS index for category {leibie_zh!r} is not trained after train()"
        )

    # NOTE(review): no index.add(xb) is called here, so the index is written
    # trained but without any vectors -- presumably they are added by a later
    # step; confirm.
    faiss.write_index(index, f'data/prompt_qikan_ivf/{file_stem}.ivf')

print(a)