You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							40 lines
						
					
					
						
							1.3 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							40 lines
						
					
					
						
							1.3 KiB
						
					
					
				
								import json
								import math

								import faiss
								import numpy as np

								# Train one FAISS IVF-Flat index per entry of data/discipline_types.json and
								# write it to data/prompt_qikan_master_doctor_ivf/<value>.ivf. The training
								# vectors for each entry are the concatenation of three .npy files sharing
								# the same file stem under data/prompt_qikan, data/prompt_master and
								# data/prompt_doctor.
								#
								# NOTE(review): index.add(...) is never called below, so each saved index is
								# trained but holds no vectors — presumably vectors are added by a later
								# step; confirm this is intended.

								d = 768  # dimension

								# JSON text is UTF-8 by specification; decode it explicitly instead of relying
								# on the platform's default locale encoding.
								with open("data/discipline_types.json", encoding="utf-8") as f:
								    lable_discipline_types = json.load(f)

								a = 0  # running total of training vectors across all entries
								for leibie_zh in lable_discipline_types:
								    # The mapped value is used as the file stem for every input/output path.
								    type_name = lable_discipline_types[leibie_zh]

								    xb = np.load(f'data/prompt_qikan/{type_name}.npy')
								    xb_1 = np.load(f'data/prompt_master/{type_name}.npy')
								    xb_2 = np.load(f'data/prompt_doctor/{type_name}.npy')
								    print("xb.shape", xb.shape)
								    print("xb_1.shape", xb_1.shape)
								    print("xb_2.shape", xb_2.shape)
								    # Stack the three sources along the first axis into one training set.
								    xb = np.concatenate((xb, xb_1, xb_2))

								    print(xb.shape)
								    # Number of clusters: floor(sqrt(number of training vectors)).
								    nlist = math.floor((xb.shape[0] ** 0.5))
								    a += xb.shape[0]
								    print(nlist)

								    quantizer = faiss.IndexFlatL2(d)
								    index = faiss.IndexIVFFlat(quantizer, d, nlist)
								    assert not index.is_trained
								    # IndexIVFFlat must be trained before use; training learns the clusters.
								    index.train(xb)
								    assert index.is_trained
								    faiss.write_index(index, f'data/prompt_qikan_master_doctor_ivf/{type_name}.ivf')

								print(a)