简介
在计算向量相似度时,通常使用近似最近邻(ANN,Approximate Nearest Neighbor)算法来加速向量检索。较为知名的 ANN 算法包括 HNSW、IVFFlat、IVFPQ 和 IVFSQ 等。其中,IVF(倒排文件索引,Inverted File Index)类算法在构建索引时需要先对向量进行聚类,以得到若干聚类中心;Elkan K-Means 是一种利用三角不等式跳过不必要距离计算的经典 K-Means 加速算法,被广泛用于向量聚类和索引构建。
在 PostgreSQL 中,pgvector 扩展提供了对向量数据的索引与搜索支持,而 Elkan K-Means 算法正是其中用于 IVF 索引构建阶段聚类过程的关键技术。接下来,我们将深入解析 Elkan K-Means 算法的具体执行流程。
算法详情
以下是该算法的 C 语言实现代码(其中的宏及其他辅助内容暂不展开,主要讲解代码逻辑):
static void
ElkanKmeans(Relation index, VectorArray samples, VectorArray centers, const IvfTypeInfo * typeInfo)
{FmgrInfo *procinfo;FmgrInfo *normprocinfo;Oid collation;int dimensions = centers->dim;int numCenters = centers->maxlen;int numSamples = samples->length;VectorArray newCenters;float *agg;int *centerCounts;int *closestCenters;float *lowerBound;float *upperBound;float *s;float *halfcdist;float *newcdist;/* Calculate allocation sizes */Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->itemsize);Size centersSize = VECTOR_ARRAY_SIZE(centers->maxlen, centers->itemsize);Size newCentersSize = VECTOR_ARRAY_SIZE(numCenters, centers->itemsize);Size aggSize = sizeof(float) * (int64) numCenters * dimensions;Size centerCountsSize = sizeof(int) * numCenters;Size closestCentersSize = sizeof(int) * numSamples;Size lowerBoundSize = sizeof(float) * numSamples * numCenters;Size upperBoundSize = sizeof(float) * numSamples;Size sSize = sizeof(float) * numCenters;Size halfcdistSize = sizeof(float) * numCenters * numCenters;Size newcdistSize = sizeof(float) * numCenters;/* Calculate total size */Size totalSize = samplesSize + centersSize + newCentersSize + aggSize + centerCountsSize + closestCentersSize + lowerBoundSize + upperBoundSize + sSize + halfcdistSize + newcdistSize;/* Check memory requirements *//* Add one to error message to ceil */if (totalSize > (Size) maintenance_work_mem * 1024L)ereport(ERROR,(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),errmsg("memory required is %zu MB, maintenance_work_mem is %d MB",totalSize / (1024 * 1024) + 1, maintenance_work_mem / 1024)));/* Ensure indexing does not overflow */if (numCenters * numCenters > INT_MAX)elog(ERROR, "Indexing overflow detected. 
Please report a bug.");/* Set support functions */procinfo = index_getprocinfo(index, 1, IVF_KMEANS_DISTANCE_PROC);normprocinfo = IvfOptionalProcInfo(index, IVF_KMEANS_NORM_PROC);collation = index->rd_indcollation[0];/* Allocate space *//* Use float instead of double to save memory */agg = palloc(aggSize);centerCounts = palloc(centerCountsSize);closestCenters = palloc(closestCentersSize);lowerBound = palloc_extended(lowerBoundSize, MCXT_ALLOC_HUGE);upperBound = palloc(upperBoundSize);s = palloc(sSize);halfcdist = palloc_extended(halfcdistSize, MCXT_ALLOC_HUGE);newcdist = palloc(newcdistSize);/* Initialize new centers */newCenters = VectorArrayInit(numCenters, dimensions, centers->itemsize);newCenters->length = numCenters;#ifdef IVFFLAT_MEMORYShowMemoryUsage(MemoryContextGetParent(CurrentMemoryContext));
#elif defined IVFPQ_MEMORY ShowMemoryUsage(MemoryContextGetParent(CurrentMemoryContext));
#endif/* Pick initial centers */InitCenters(index, samples, centers, lowerBound);/* Assign each x to its closest initial center c(x) = argmin d(x,c) */for (int64 j = 0; j < numSamples; j++){float minDistance = FLT_MAX;int closestCenter = 0;/* Find closest center */for (int64 k = 0; k < numCenters; k++){/* TODO Use Lemma 1 in k-means++ initialization */float distance = lowerBound[j * numCenters + k];if (distance < minDistance){minDistance = distance;closestCenter = k;}}upperBound[j] = minDistance;closestCenters[j] = closestCenter;}/* Give 500 iterations to converge */for (int iteration = 0; iteration < 500; iteration++){int changes = 0;bool rjreset;/* Can take a while, so ensure we can interrupt */CHECK_FOR_INTERRUPTS();/* Step 1: For all centers, compute distance */for (int64 j = 0; j < numCenters; j++){Datum vec = PointerGetDatum(VectorArrayGet(centers, j));for (int64 k = j + 1; k < numCenters; k++){float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, vec, PointerGetDatum(VectorArrayGet(centers, k))));halfcdist[j * numCenters + k] = distance;halfcdist[k * numCenters + j] = distance;}}/* For all centers c, compute s(c) */for (int64 j = 0; j < numCenters; j++){float minDistance = FLT_MAX;for (int64 k = 0; k < numCenters; k++){