HNSW(Hierarchical Navigable Small World)是目前最先进的近似最近邻(ANN)搜索算法之一,由Malkov和Yashunin在2018年提出。它结合了:

小世界网络具有以下重要特性:
每个节点的层级按照以下概率分布确定:

$$l = \lfloor -\ln(\mathrm{unif}(0,1)) \cdot m_L \rfloor, \qquad m_L = \frac{1}{\ln 2}$$

其中 $\mathrm{unif}(0,1)$ 是 $(0,1)$ 上的均匀随机数,$m_L$ 是层级归一化因子,$\ln$ 是自然对数,这确保了:
package com.jvector.index;
import com.jvector.core.Vector;
import java.io.Serializable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Set;
import java.util.Map;
/**
 * A single node in the HNSW graph: an immutable id/vector/level triple plus
 * a mutable, thread-safe per-layer adjacency list.
 */
public class HnswNode implements Serializable {

    private static final long serialVersionUID = 1L;

    private final long id;
    private final Vector vector;
    private final int level;

    /** Layer index -> ids of this node's neighbors on that layer. */
    private final Map<Integer, Set<Long>> connections;

    public HnswNode(long id, Vector vector, int level) {
        this.id = id;
        this.vector = vector;
        this.level = level;
        this.connections = new ConcurrentHashMap<>();
        // A node participates in every layer from 0 up to its own level.
        for (int layer = 0; layer <= level; layer++) {
            connections.put(layer, ConcurrentHashMap.newKeySet());
        }
    }

    /** @return the node's unique id */
    public long getId() {
        return id;
    }

    /** @return the stored vector */
    public Vector getVector() {
        return vector;
    }

    /** @return the highest layer this node appears on */
    public int getLevel() {
        return level;
    }

    /**
     * Returns the LIVE neighbor set for the given layer (callers mutate it
     * directly during pruning), or null when the node is absent from it.
     */
    public Set<Long> getConnections(int layerLevel) {
        return connections.get(layerLevel);
    }

    /** Adds a directed edge on the given layer; no-op for unknown layers. */
    public void addConnection(int layerLevel, long nodeId) {
        Set<Long> layer = connections.get(layerLevel);
        if (layer != null) {
            layer.add(nodeId);
        }
    }

    /** Removes a directed edge on the given layer; no-op for unknown layers. */
    public void removeConnection(int layerLevel, long nodeId) {
        Set<Long> layer = connections.get(layerLevel);
        if (layer != null) {
            layer.remove(nodeId);
        }
    }

    /** @return the neighbor count on the layer, or 0 when absent from it */
    public int getConnectionCount(int layerLevel) {
        Set<Long> layer = connections.get(layerLevel);
        return layer == null ? 0 : layer.size();
    }
}
// RNG for level assignment, created once on first use.
// BUG FIX: the original constructed `new Random(seed)` on EVERY call, so
// every call drew the same first sample and every node got the identical
// level, flattening the hierarchy. A single shared instance keeps runs
// seed-deterministic while producing a proper geometric level distribution.
// Lazy init avoids field-initialization-order issues with `seed`;
// assignLevel() is only reached under the global write lock (see add()),
// so no extra synchronization is needed — TODO confirm no other callers.
private Random levelRandom;

/**
 * Assigns a level to a new node via the exponential distribution from the
 * HNSW paper: level = floor(-ln(unif(0,1)) * mL), with mL = 1/ln(2).
 *
 * @return a level in [0, maxLevel]
 */
private int assignLevel() {
    if (levelRandom == null) {
        levelRandom = new Random(seed);
    }
    double levelMultiplier = 1.0 / Math.log(2.0);
    // nextDouble() == 0 yields +infinity, which the min() below clamps.
    int level = (int) (-Math.log(levelRandom.nextDouble()) * levelMultiplier);
    return Math.min(level, maxLevel);
}
/**
 * Inserts a vector into the index under the global write lock.
 *
 * @param id     unique id for the vector
 * @param vector the vector to index
 * @throws IllegalArgumentException if a vector with this id already exists
 */
public void add(long id, Vector vector) {
    ensureLockInitialized();
    globalLock.writeLock().lock();
    try {
        if (nodes.containsKey(id)) {
            throw new IllegalArgumentException("Vector with id " + id + " already exists");
        }
        // 1. Assign a level and register the node.
        int nodeLevel = assignLevel();
        HnswNode newNode = new HnswNode(id, vector, nodeLevel);
        nodes.put(id, newNode);
        // 2. The very first node becomes the entry point.
        if (entryPointId == null) {
            entryPointId = id;
            return;
        }
        // 3. Greedy descent from the entry point's top layer down to
        // nodeLevel + 1, keeping one best candidate per layer.
        // BUG FIX: the original issued a single search at nodeLevel + 1
        // instead of descending layer by layer as the HNSW algorithm
        // requires, and called entryPoints.get(0) without an empty check.
        Long currentEntry = entryPointId;
        for (int level = getEntryPointLevel(); level > nodeLevel; level--) {
            List<SearchResult> found = searchLayerEf(vector, currentEntry, 1, level);
            if (!found.isEmpty()) {
                currentEntry = found.get(0).getId();
            }
        }
        // 4. From min(nodeLevel, entry level) down to 0: gather
        // efConstruction candidates and create bidirectional links.
        for (int level = Math.min(nodeLevel, getEntryPointLevel()); level >= 0; level--) {
            List<SearchResult> candidates = searchLayerEf(vector, currentEntry, efConstruction, level);
            // Layer 0 allows more neighbors (maxM) than upper layers (M).
            int levelM = (level == 0) ? maxM : M;
            connectNewNode(newNode, candidates, level, levelM);
            // connectNewNode sorts candidates in place, so index 0 is the
            // best match — use it to seed the next (lower) layer.
            if (!candidates.isEmpty()) {
                currentEntry = candidates.get(0).getId();
            }
        }
        // 5. Promote the new node to global entry point if it is taller.
        if (nodeLevel > getEntryPointLevel()) {
            entryPointId = id;
        }
    } finally {
        globalLock.writeLock().unlock();
    }
}
/**
 * Wires a freshly inserted node into one layer: picks up to {@code m}
 * neighbors from {@code candidates}, creates bidirectional links, and
 * prunes any neighbor pushed over its connection budget.
 *
 * <p>Note: sorts {@code candidates} in place by ascending distance;
 * callers rely on this ordering after the call.
 */
private void connectNewNode(HnswNode newNode, List<SearchResult> candidates, int level, int m) {
    candidates.sort(Comparator.comparingDouble(SearchResult::getDistance));
    // Heuristic selection of the best m links, then link both directions.
    for (SearchResult chosen : selectBestConnections(candidates, m)) {
        long neighborId = chosen.getId();
        newNode.addConnection(level, neighborId);
        HnswNode neighbor = nodes.get(neighborId);
        if (neighbor == null) {
            continue;
        }
        neighbor.addConnection(level, newNode.getId());
        // The back-link may push the neighbor over its budget — trim it.
        if (neighbor.getConnectionCount(level) > m) {
            pruneConnections(neighbor, level, m);
        }
    }
}
/**
 * Greedily selects up to {@code m} candidates, preferring links that are
 * both close to the new node and diverse relative to links already chosen
 * (score computed by {@link #calculateConnectionScore}).
 *
 * @return a new list of at most m selected results
 */
private List<SearchResult> selectBestConnections(List<SearchResult> candidates, int m) {
    if (candidates.size() <= m) {
        return new ArrayList<>(candidates);
    }
    List<SearchResult> selected = new ArrayList<>();
    List<SearchResult> remaining = new ArrayList<>(candidates);
    while (selected.size() < m && !remaining.isEmpty()) {
        int bestIndex = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < remaining.size(); i++) {
            double score = calculateConnectionScore(remaining.get(i), selected);
            if (score > bestScore) {
                bestScore = score;
                bestIndex = i;
            }
        }
        // BUG FIX: if every score compares false against -Infinity (all
        // NaN, e.g. from NaN distances), the original left `best` null,
        // removed nothing, and looped forever. Bail out instead.
        if (bestIndex < 0) {
            break;
        }
        // Index-based removal also avoids the O(n) equality scan of
        // List.remove(Object).
        selected.add(remaining.remove(bestIndex));
    }
    return selected;
}
/**
 * Scores a candidate link as proximity times diversity: closeness to the
 * query node, penalized when the candidate clusters near an already
 * selected neighbor.
 *
 * @param candidate the link being evaluated
 * @param selected  links chosen so far (may be empty)
 * @return higher is better
 */
private double calculateConnectionScore(SearchResult candidate, List<SearchResult> selected) {
    // Map distance in [0, inf) to a proximity score in (0, 1].
    double proximity = 1.0 / (1.0 + candidate.getDistance());
    if (selected.isEmpty()) {
        // No diversity information yet — proximity alone decides.
        return proximity;
    }
    // Diversity = distance to the nearest already-selected neighbor.
    float nearest = Float.MAX_VALUE;
    for (SearchResult picked : selected) {
        float d = computeEngine.distance(
                candidate.getVector().getData(),
                picked.getVector().getData(),
                distanceMetric
        );
        nearest = Math.min(nearest, d);
    }
    return proximity * nearest;
}
/**
 * Greedy best-first search confined to a single layer of the graph
 * (the SEARCH-LAYER routine of the HNSW algorithm). Repeatedly expands the
 * closest unexpanded candidate until no candidate can improve the current
 * ef-sized result set.
 *
 * @param query       the query vector
 * @param entryId     id of the node to start from
 * @param efLocal     size of the dynamic result list (the "ef" parameter)
 * @param targetLevel layer whose edges are traversed
 * @return up to efLocal results sorted by ascending distance; empty list
 *         when the entry node does not exist
 */
private List<SearchResult> searchLayerEf(Vector query, Long entryId, int efLocal, int targetLevel) {
    Set<Long> visited = new HashSet<>();
    // Min-heap: closest first — the next candidate to expand.
    PriorityQueue<SearchResult> candidates = new PriorityQueue<>(
            Comparator.comparingDouble(SearchResult::getDistance));
    // Max-heap of the best efLocal results so far; peek() is the worst kept
    // result and serves as the pruning bound.
    PriorityQueue<SearchResult> dynamicList = new PriorityQueue<>(
            Comparator.comparingDouble(SearchResult::getDistance).reversed());
    // Seed both heaps with the entry point.
    HnswNode entryNode = nodes.get(entryId);
    if (entryNode == null) {
        return Collections.emptyList();
    }
    float distance = computeEngine.distance(
            query.getData(),
            entryNode.getVector().getData(),
            distanceMetric
    );
    SearchResult entryResult = new SearchResult(entryId, distance, entryNode.getVector());
    candidates.add(entryResult);
    dynamicList.add(entryResult);
    visited.add(entryId);
    // Greedy expansion loop.
    while (!candidates.isEmpty()) {
        SearchResult current = candidates.poll();
        // Stop when the closest remaining candidate is farther than the
        // worst of the efLocal results already kept — no improvement possible.
        if (dynamicList.size() >= efLocal &&
                current.getDistance() > dynamicList.peek().getDistance()) {
            break;
        }
        // Examine every neighbor of the expanded node on this layer.
        HnswNode currentNode = nodes.get(current.getId());
        Set<Long> connections = currentNode.getConnections(targetLevel);
        // getConnections returns null when the node is absent from this layer.
        if (connections != null) {
            for (Long neighborId : connections) {
                if (!visited.contains(neighborId)) {
                    // Mark visited before scoring so each node is expanded once.
                    visited.add(neighborId);
                    HnswNode neighborNode = nodes.get(neighborId);
                    if (neighborNode != null) {
                        float neighborDistance = computeEngine.distance(
                                query.getData(),
                                neighborNode.getVector().getData(),
                                distanceMetric
                        );
                        SearchResult neighborResult = new SearchResult(
                                neighborId, neighborDistance, neighborNode.getVector());
                        // Accept when the list is not full yet, or when this
                        // neighbor beats the current worst kept result.
                        if (dynamicList.size() < efLocal ||
                                neighborDistance < dynamicList.peek().getDistance()) {
                            candidates.add(neighborResult);
                            dynamicList.add(neighborResult);
                            // Evict the worst to cap the list at efLocal.
                            if (dynamicList.size() > efLocal) {
                                dynamicList.poll();
                            }
                        }
                    }
                }
            }
        }
    }
    // Heap iteration order is unspecified — sort ascending before returning.
    List<SearchResult> results = new ArrayList<>(dynamicList);
    results.sort(Comparator.comparingDouble(SearchResult::getDistance));
    return results;
}
/**
 * Searches for the k nearest neighbors of {@code query}.
 *
 * @param query    the query vector
 * @param k        number of results requested
 * @param searchEf beam width for the layer-0 search (effective ef is
 *                 max(searchEf, k))
 * @return up to k results sorted by ascending distance; empty if the index
 *         is empty
 */
public List<SearchResult> search(Vector query, int k, int searchEf) {
    ensureLockInitialized();
    globalLock.readLock().lock();
    try {
        if (entryPointId == null) {
            return Collections.emptyList();
        }
        HnswNode entryPoint = nodes.get(entryPointId);
        // BUG FIX: a stale entryPointId (node map out of sync) made the
        // original NPE on entryPoint.getLevel(). Fail soft instead.
        if (entryPoint == null) {
            return Collections.emptyList();
        }
        Long currentEntry = entryPointId;
        // Phase 1: greedy descent from the top layer to layer 1, tracking a
        // single best candidate per layer.
        for (int level = entryPoint.getLevel(); level > 0; level--) {
            List<SearchResult> results = searchLayerEf(query, currentEntry, 1, level);
            if (!results.isEmpty()) {
                currentEntry = results.get(0).getId();
            }
        }
        // Phase 2: wide beam search on layer 0 with ef >= k.
        List<SearchResult> finalResults = searchLayerEf(query, currentEntry,
                Math.max(searchEf, k), 0);
        finalResults.sort(Comparator.comparingDouble(SearchResult::getDistance));
        // Copy so callers do not retain a subList view of the working list.
        return new ArrayList<>(finalResults.subList(0, Math.min(k, finalResults.size())));
    } finally {
        globalLock.readLock().unlock();
    }
}
/**
 * Trims a node's layer adjacency to at most {@code maxConnections}
 * entries, keeping the heuristically best neighbors. Mutates the live
 * connection set (clear + refill), so it must run while the graph is
 * exclusively locked.
 */
private void pruneConnections(HnswNode node, int level, int maxConnections) {
    Set<Long> connections = node.getConnections(level);
    if (connections.size() <= maxConnections) {
        return;
    }
    // Score every current neighbor by its distance from this node.
    List<SearchResult> scored = new ArrayList<>();
    for (Long neighborId : connections) {
        HnswNode neighbor = nodes.get(neighborId);
        if (neighbor == null) {
            continue;
        }
        float distance = computeEngine.distance(
                node.getVector().getData(),
                neighbor.getVector().getData(),
                distanceMetric
        );
        scored.add(new SearchResult(neighborId, distance, neighbor.getVector()));
    }
    // Re-apply the diversity heuristic, then rewrite the adjacency in place.
    List<SearchResult> survivors = selectBestConnections(scored, maxConnections);
    connections.clear();
    for (SearchResult survivor : survivors) {
        connections.add(survivor.getId());
    }
}
/**
 * Removes the vector with the given id and detaches all of its links.
 *
 * @param id id of the vector to delete
 * @return true if a node was removed, false if the id was unknown
 */
public boolean remove(long id) {
    ensureLockInitialized();
    globalLock.writeLock().lock();
    try {
        HnswNode victim = nodes.get(id);
        if (victim == null) {
            return false;
        }
        // Detach from neighbors first, then drop the node itself.
        removeNodeConnections(victim);
        nodes.remove(id);
        // Deleting the entry point forces a rescan for the tallest node.
        if (entryPointId != null && entryPointId.equals(id)) {
            updateEntryPoint();
        }
        return true;
    } finally {
        globalLock.writeLock().unlock();
    }
}
/**
 * Detaches a node from the graph: on every layer it participates in,
 * deletes the back-links held by its neighbors, then empties its own sets.
 */
private void removeNodeConnections(HnswNode nodeToRemove) {
    for (int layer = 0; layer <= nodeToRemove.getLevel(); layer++) {
        Set<Long> neighbors = nodeToRemove.getConnections(layer);
        if (neighbors == null) {
            continue;
        }
        // Iterate over a snapshot — removeConnection may touch this very
        // set if the graph ever contains a self-link.
        for (Long neighborId : new HashSet<>(neighbors)) {
            HnswNode neighbor = nodes.get(neighborId);
            if (neighbor != null) {
                neighbor.removeConnection(layer, nodeToRemove.getId());
            }
        }
        neighbors.clear();
    }
}
/**
 * Recomputes the global entry point as the highest-level surviving node.
 * Leaves {@code entryPointId} null when the index is empty. O(n) scan,
 * but only runs when the entry point itself was deleted.
 */
private void updateEntryPoint() {
    entryPointId = null;
    // Renamed from `maxLevel` to avoid shadowing the field of that name.
    int bestLevel = -1;
    for (HnswNode candidate : nodes.values()) {
        if (candidate.getLevel() > bestLevel) {
            bestLevel = candidate.getLevel();
            entryPointId = candidate.getId();
        }
    }
}
/**
 * Best-effort cache warm-up: touches each node's vector data so it is
 * pulled toward the CPU caches before distance computations.
 * NOTE(review): the JIT may eliminate the unused read entirely — treat
 * this as a hint, not a guarantee.
 */
private void prefetchNodes(Set<Long> nodeIds) {
    for (Long nodeId : nodeIds) {
        HnswNode node = nodes.get(nodeId);
        if (node == null) {
            continue;
        }
        float[] data = node.getVector().getData();
        // BUG FIX: the original indexed data[0] unconditionally and threw
        // ArrayIndexOutOfBoundsException for zero-length vectors.
        if (data.length > 0) {
            @SuppressWarnings("unused")
            float sum = data[0] + data[data.length - 1];
        }
    }
}
/**
 * Parallel variant of {@link #search} that splits the layer range across a
 * Fork/Join task tree. Falls back to the sequential search on single-core
 * machines.
 *
 * @param query    the query vector
 * @param k        number of results requested
 * @param searchEf beam width per layer search
 * @return merged search results; empty if the index is empty
 */
public List<SearchResult> parallelSearch(Vector query, int k, int searchEf) {
    if (Runtime.getRuntime().availableProcessors() == 1) {
        return search(query, k, searchEf);
    }
    ensureLockInitialized();
    globalLock.readLock().lock();
    try {
        // BUG FIX: on an empty index the task called
        // searchLayerEf(query, null, ...) and nodes.get(null) threw
        // NullPointerException (ConcurrentHashMap rejects null keys).
        if (entryPointId == null) {
            return Collections.emptyList();
        }
        // BUG FIX: the original allocated a new ForkJoinPool per call and
        // never shut it down, leaking pool threads. Use the shared common
        // pool instead.
        ParallelSearchTask task = new ParallelSearchTask(query, k, searchEf, 0, getEntryPointLevel());
        return ForkJoinPool.commonPool().invoke(task);
    } finally {
        globalLock.readLock().unlock();
    }
}
/**
 * Fork/Join task that recursively halves the [fromLevel, toLevel) layer
 * range, searches the two halves concurrently, and merges the partial
 * results. Intentionally a non-static inner class: it uses the enclosing
 * index's searchLayerEf, entryPointId, and mergeResults.
 */
private class ParallelSearchTask extends RecursiveTask<List<SearchResult>> {
    private final Vector query;
    private final int k;
    private final int searchEf;
    private final int fromLevel;
    private final int toLevel;

    public ParallelSearchTask(Vector query, int k, int searchEf, int fromLevel, int toLevel) {
        this.query = query;
        this.k = k;
        this.searchEf = searchEf;
        this.fromLevel = fromLevel;
        this.toLevel = toLevel;
    }

    @Override
    protected List<SearchResult> compute() {
        // Base case: a single layer — search it directly.
        if (toLevel - fromLevel <= 1) {
            return searchLayerEf(query, entryPointId, searchEf, fromLevel);
        }
        // Split the range, fork the lower half, compute the upper inline.
        int midLevel = (fromLevel + toLevel) / 2;
        ParallelSearchTask lowerHalf = new ParallelSearchTask(query, k, searchEf, fromLevel, midLevel);
        ParallelSearchTask upperHalf = new ParallelSearchTask(query, k, searchEf, midLevel, toLevel);
        lowerHalf.fork();
        List<SearchResult> upperResults = upperHalf.compute();
        List<SearchResult> lowerResults = lowerHalf.join();
        return mergeResults(lowerResults, upperResults, k);
    }
}
本章深入介绍了HNSW算法的核心原理和实现:
HNSW算法是向量搜索引擎的核心,其分层结构和小世界网络特性使其在保证搜索精度的同时,实现了优异的搜索性能。
思考题: