-1

我想使用 Hierarchical Navigable Small Worlds (HNSW) 和局部敏感哈希 (LSH) 对数据进行分类。目前我已经按如下方式封装了 Annoy 算法来做分类:

# Wrapper around annoy.AnnoyIndex exposing an sklearn-style fit/predict interface (k-NN majority vote)
class Annoy():
    """Nearest-neighbour majority-vote classifier backed by ``annoy.AnnoyIndex``.

    The object follows a ``fit`` / ``predict`` protocol: ``fit`` indexes the
    training vectors, ``predict`` looks up the approximate nearest neighbours
    of each query vector and returns the most frequent label among them.
    """

    def __init__(self, n_neighbors=5, metric='euclidean', n_trees=10):
        """Store the hyper-parameters; no index is built until ``fit``.

        Args:
            n_neighbors: Number of neighbours requested per query in
                ``predict``.
            metric: Metric name forwarded to ``annoy.AnnoyIndex``.
            n_trees: Tree count forwarded to ``AnnoyIndex.build``.
        """
        self.n_neighbors, self.metric, self.n_trees = n_neighbors, metric, n_trees

    def fit(self, X_train, y_train):
        """Index the training vectors and remember their labels.

        Args:
            X_train: Training data exposing ``.shape`` with at least two
                entries and yielding one vector per iteration step
                (presumably a 2-D numpy array -- confirm against callers).
            y_train: Labels, indexable by the integer item ids assigned
                here (0 .. N_train - 1).

        Returns:
            ``self``, so calls can be chained.
        """
        self.N_feat = X_train.shape[1]
        self.N_train = X_train.shape[0]
        self.y_train = y_train

        index = annoy.AnnoyIndex(self.N_feat, metric=self.metric)
        self.t = index
        # Item ids are the row positions, which is what ``predict`` later
        # uses to look labels up in ``y_train``.
        for item_id, row in zip(range(self.N_train), X_train):
            index.add_item(item_id, row)
        index.build(self.n_trees)
        return self

    def predict(self, X_test):
        """Return a list with one predicted label per vector in ``X_test``.

        Each prediction is the most frequent training label among the
        ``n_neighbors`` items returned by the index for that vector.
        """
        def vote(query):
            neighbour_ids = self.t.get_nns_by_vector(query, self.n_neighbors)
            return most_frequent([self.y_train[j] for j in neighbour_ids])

        return [vote(query) for query in X_test]
    
def most_frequent(List):
    """Return the element that occurs most often in ``List``.

    Ties are resolved by whatever ordering ``Counter.most_common`` applies.
    An empty input raises ``IndexError`` because there is no top entry.
    """
    tally = Counter(List)
    top_entry = tally.most_common(1)[0]
    winner, _count = top_entry
    return winner
4

0 回答 0