Product Quantization for Nearest Neighbor Search

Prerequisites

Core Idea

FAISS


# %%

import numpy as np
import faiss

# %%

d = 64                           # dimension
nb = 100000                      # database size
nq = 10000                       # nb of queries
np.random.seed(1234)             # make reproducible
xb = np.random.random((nb, d)).astype('float32')
xb[:, 0] += np.arange(nb) / 1000.
xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.

xb[:5], xq[:5]
# %%

# Plain exact (brute-force) search, O(nD) complexity per query
# k: number of nearest neighbors to return

index = faiss.IndexFlatL2(d)   # build the index
print(index.is_trained)
index.add(xb)                  # add vectors to the index
print(index.ntotal)

k = 4                          # we want to see 4 nearest neighbors
D, I = index.search(xb[:5], k) # sanity check
print(I)
print(D)                       # the first column is zero: each vector's nearest neighbor is itself
D, I = index.search(xq, k)     # actual search
print(I[:5])                   # neighbors of the 5 first queries
print(I[-5:])                  # neighbors of the 5 last queries

# True
# 100000
# [[  0 393 363  78]
#  [  1 555 277 364]
#  [  2 304 101  13]
#  [  3 173  18 182]
#  [  4 288 370 531]]
# [[0.        7.1751733 7.2076297 7.2511625]
#  [0.        6.3235645 6.684581  6.799946 ]
#  [0.        5.7964087 6.391736  7.2815123]
#  [0.        7.2779055 7.5279875 7.662846 ]
#  [0.        6.7638035 7.2951202 7.368815 ]]
# [[ 381  207  210  477]
#  [ 526  911  142   72]
#  [ 838  527 1290  425]
#  [ 196  184  164  359]
#  [ 526  377  120  425]]
# [[ 9900 10500  9309  9831]
#  [11055 10895 10812 11321]
#  [11353 11103 10164  9787]
#  [10571 10664 10632  9638]
#  [ 9628  9554 10036  9582]]
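
# %%

# A minimal NumPy sketch (added for illustration, not part of faiss) of the
# O(nD) brute-force search that IndexFlatL2 performs: compute every squared
# L2 distance between the queries and the database, then keep the k smallest.
# brute_force_search is a helper written for this note.

def brute_force_search(queries, database, k):
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, computed for all pairs at once
    d2 = ((queries ** 2).sum(1, keepdims=True)
          - 2 * queries @ database.T
          + (database ** 2).sum(1))
    I = np.argsort(d2, axis=1)[:, :k]        # indices of the k nearest neighbors
    D = np.take_along_axis(d2, I, axis=1)    # their squared distances
    return D, I

D_np, I_np = brute_force_search(xb[:5], xb, k)
print(I_np)                                  # should match the sanity check above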

# %%

# Vector quantization partitions the database into Voronoi cells
# 1. coarse search: compare the query against the cell centroids
# 2. exact search within the candidate cells only
# nlist: number of cells
# nprobe: the number of cells that are visited to perform a search

nlist = 100
k = 4
quantizer = faiss.IndexFlatL2(d)  # the other index
index = faiss.IndexIVFFlat(quantizer, d, nlist)
assert not index.is_trained
index.train(xb)
assert index.is_trained

index.add(xb)                  # add may be a bit slower as well
D, I = index.search(xq, k)     # actual search
print(I[-5:])                  # neighbors of the 5 last queries
print(D[-5:])
index.nprobe = 10              # default nprobe is 1, try a few more
D, I = index.search(xq, k)
print(I[-5:])                  # neighbors of the 5 last queries
print(D[-5:])

# [[ 9900  9309  9810 10048]
#  [11055 10895 10812 11321]
#  [11353 10164  9787 10719]
#  [10571 10664 10632 10203]
#  [ 9628  9554  9582 10304]]
# [[6.531542  7.003928  7.1329813 7.32678  ]
#  [4.3351846 5.2369733 5.3194113 5.7031565]
#  [6.0726957 6.6140213 6.732214  6.967677 ]
#  [6.637367  6.6487756 6.8578253 7.091343 ]
#  [6.218346  6.4524803 6.581311  6.582589 ]]
# [[ 9900 10500  9309  9831]
#  [11055 10895 10812 11321]
#  [11353 11103 10164  9787]
#  [10571 10664 10632  9638]
#  [ 9628  9554 10036  9582]]
# [[6.531542  6.978715  7.003928  7.013735 ]
#  [4.3351846 5.2369733 5.3194113 5.7031565]
#  [6.0726957 6.576689  6.6140213 6.732214 ]
#  [6.637367  6.6487756 6.8578253 7.009651 ]
#  [6.218346  6.4524803 6.5487304 6.581311 ]]
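
# %%

# A small added check (not in the original tutorial): measure how closely the
# IVF index matches exact search. The score below is the fraction of queries
# whose true nearest neighbor (from IndexFlatL2) appears in the IVF top-k.

index_exact = faiss.IndexFlatL2(d)
index_exact.add(xb)
_, I_gt = index_exact.search(xq, 1)              # ground-truth nearest neighbor

for nprobe in (1, 10, nlist):                    # nprobe = nlist scans every cell
    index.nprobe = nprobe
    _, I_ivf = index.search(xq, k)
    recall = (I_ivf == I_gt).any(axis=1).mean()  # true NN found in the top-k?
    print(f"nprobe={nprobe:3d}  recall@1: {recall:.3f}")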


# %%

# Product Quantization further compresses the memory footprint

nlist = 100
m = 8                             # number of subquantizers
k = 4
quantizer = faiss.IndexFlatL2(d)  # this remains the same
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
                                    # 8 specifies that each sub-vector is encoded as 8 bits
index.train(xb)
index.add(xb)
D, I = index.search(xb[:5], k) # sanity check
print(I)
print(D)
index.nprobe = 10              # make comparable with experiment above
D, I = index.search(xq, k)     # search
print(I[-5:])


# [[   0   78  159  424]
#  [   1  555 1063    5]
#  [   2  304  179  134]
#  [   3  182  139  265]
#  [   4  288  531   95]]
# [[1.6623363 5.9748235 6.432223  6.598652 ]
#  [1.3124933 5.700293  5.9881134 6.2415333]
#  [1.8071327 5.6087813 6.116017  6.2952175]
#  [1.7520823 6.575944  6.6075144 6.7987113]
#  [1.4028182 5.7459674 6.2515984 6.298007 ]]
# [[10437 10560  9842  9432]
#  [11373 10507  9014 10494]
#  [10719 11291 10494 11383]
#  [10630  9671 10972 10040]
#  [10304  9878  9229 10370]]
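
# %%

# A rough NumPy/faiss sketch of the core idea of Product Quantization (written
# for this note as an illustration, not the IndexIVFPQ internals): split each
# d-dim vector into m subvectors, learn a 256-centroid codebook per subspace,
# and store each vector as m one-byte codes, i.e. m = 8 bytes instead of
# d * 4 = 256 bytes of raw float32.

dsub = d // m                                    # dimension of each subvector
codebooks = []
codes = np.empty((nb, m), dtype='uint8')

for i in range(m):
    sub = np.ascontiguousarray(xb[:, i * dsub:(i + 1) * dsub])
    km = faiss.Kmeans(dsub, 256, niter=10)       # 256 centroids -> 8-bit codes
    km.train(sub)
    _, assign = km.index.search(sub, 1)          # nearest centroid per subvector
    codes[:, i] = assign.ravel()
    codebooks.append(km.centroids)               # (256, dsub) codebook

print(codes[:3])                                 # each row: m one-byte codes
print(f"compressed: {codes.nbytes} bytes, raw: {xb.nbytes} bytes")

# Asymmetric distance computation (ADC) for one query: build an (m, 256)
# lookup table of partial squared distances, then sum m table entries per
# database vector instead of touching the original floats.
q = xq[0]
lut = np.stack([((codebooks[i] - q[i * dsub:(i + 1) * dsub]) ** 2).sum(1)
                for i in range(m)])
approx_d2 = lut[np.arange(m), codes].sum(1)      # approximate squared distances
print(approx_d2.argsort()[:k])                   # approximate nearest neighbors of xq[0]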

References

  1. Jegou H., Douze M., and Schmid C. Product Quantization for Nearest Neighbor Search. IEEE TPAMI, 2011.