ANNOY(Approximate Nearest Neighbors Oh Yeah)算法能够帮助我们高效的查找近邻的 N 个向量。其基本原理:就是将所有向量按照空间进行划分,直到子空间小于等于 K 个向量位置。如下图所示:
随机选择两个向量,在两点的直线的中心垂直画一条直线将样本分割成两部分。接下来,按照这个思路继续划分每个子空间,直到空间向量数量小于等于 K 个为止。
上面的划分过程,也可以看成构建二叉树的过程,如下图所示:
ANNOY 构建的二叉树都是随机构建的,并且 ANNOY 会构建多个这样的随机二叉树,树的数量可以由我们自己来指定。当来了一个新向量时,该向量在每棵树上必然属于某个子空间,假设我们有 5 棵树,则将新向量所在的 5 个子空间中的所有向量中,找出 N 个最相似的向量。
上图中,可以看到新向量所属的不同子空间,并且每个子空间都有多个向量,ANNOY 就是从这些向量中找到 N 个相近的向量返回。
pip install annoy
示例代码:
from annoy import AnnoyIndex import torch # 1. annoy 构建搜索 def test01(): # 构建索引 # metric 为距离度量方法,可选的有: # 向量角: "angular" # 欧式距离: "euclidean" # 曼哈顿距离: "manhattan" # 汉明距离: "hamming" # 点积: "dot" # f 参数表示向量的维度 index = AnnoyIndex(f=3, metric='euclidean') # 插入向量: 第一个参数为插入位置,第二个参数为插入向量 index.add_item(0, torch.tensor([1, 2, 3])) index.add_item(1, torch.tensor([7, 8, 9])) index.add_item(2, torch.tensor([4, 5, 6])) # 构建二叉树 # on_disk_build 方法将树构建到文件中 # build 方法会将树构建到内存中 index.build(n_trees=5) # 查询向量: # 第一个参数为待查询向量的索引 # 第二个参数为要返回的向量个数 # 返回值为已查到向量的索引 find_index = index.get_nns_by_item(i=0, n=2) print('find_index:', find_index) # 第一个参数为待查询的向量 # 第二个参数为要返回的向量个数 # 返回值为已查到向量的索引 find_index = index.get_nns_by_vector(vector=torch.tensor([0, 1, 1]), n=2) print('find_index:', find_index) # 2. annoy 存储加载 def test02(): index = AnnoyIndex(f=3, metric='euclidean') index.add_item(0, torch.tensor([1, 2, 3])) index.add_item(1, torch.tensor([7, 8, 9])) index.add_item(2, torch.tensor([4, 5, 6])) index.build(n_trees=5) # 存储索引 index.save('index.ann') # 加载索引 index = AnnoyIndex(f=3, metric='euclidean') index.load('index.ann') find_index = index.get_nns_by_item(i=0, n=2) print('find_index:', find_index) # 3. annoy 其他函数 def test03(): index = AnnoyIndex(f=3, metric='euclidean') index.add_item(0, torch.tensor([1, 2, 3])) index.add_item(1, torch.tensor([7, 8, 9])) index.add_item(2, torch.tensor([4, 5, 6])) index.build(n_trees=5) # 获得指定索引位置的向量 print('位置向量:', index.get_item_vector(0)) # 返回索引向量个数 print('向量个数:', index.get_n_items()) # 返回索引树的数量 print('树的数量:', index.get_n_trees()) # 获得索引中指定下标位置的两个向量的距离 print('向量距离:', index.get_distance(0, 1)) if __name__ == '__main__': test01() print('-' * 30) test02() print('-' * 30) test03()
程序输出结果:
find_index: [0, 2] find_index: [0, 2] ------------------------------ find_index: [0, 2] ------------------------------ 位置向量: [1.0, 2.0, 3.0] 向量个数: 3 树的数量: 5 向量距离: 10.392304420471191