HNSW索引在整个生命周期中会经历多个状态,因此需要一套合理的状态管理机制:

package com.jvector.index;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * Index state manager.
 *
 * <p>Tracks the lifecycle state of an HNSW index and only allows the transitions listed in
 * {@link #isValidTransition}. The state and the associated error message are changed together
 * under the write lock; state reads take the read lock.
 */
public class IndexStateManager {
    /** Lifecycle states of an index. */
    public enum IndexState {
        CREATED,   // created, not built yet
        BUILDING,  // build in progress
        READY,     // ready; searches may run
        SEARCHING, // search in progress (used by isValidTransition and canSearch)
        UPDATING,  // update in progress
        ERROR,     // error state
        CLOSED     // closed
    }

    private final AtomicReference<IndexState> currentState = new AtomicReference<>(IndexState.CREATED);
    private final ReadWriteLock stateLock = new ReentrantReadWriteLock();
    private volatile String errorMessage;

    /**
     * Attempts to move the index to {@code newState}.
     *
     * <p>A successful transition to any state other than ERROR clears the stored error message.
     *
     * @param newState target state
     * @return {@code true} if the transition is allowed and was applied; {@code false} otherwise,
     *     in which case the state is unchanged
     */
    public boolean transitionTo(IndexState newState) {
        stateLock.writeLock().lock();
        try {
            return applyTransitionLocked(newState);
        } finally {
            stateLock.writeLock().unlock();
        }
    }

    /** Performs a validated transition; the caller must hold the write lock. */
    private boolean applyTransitionLocked(IndexState newState) {
        IndexState current = currentState.get();
        if (!isValidTransition(current, newState)) {
            return false;
        }
        currentState.set(newState);
        if (newState != IndexState.ERROR) {
            errorMessage = null;
        }
        return true;
    }

    /**
     * Returns whether the state machine allows moving from {@code from} to {@code to}.
     *
     * <p>CLOSED is terminal: no transition out of it is allowed.
     */
    private boolean isValidTransition(IndexState from, IndexState to) {
        switch (from) {
            case CREATED:
                return to == IndexState.BUILDING || to == IndexState.CLOSED;
            case BUILDING:
                return to == IndexState.READY || to == IndexState.ERROR;
            case READY:
                return to == IndexState.UPDATING || to == IndexState.SEARCHING || to == IndexState.CLOSED;
            case UPDATING:
                return to == IndexState.READY || to == IndexState.ERROR;
            case SEARCHING:
                return to == IndexState.READY;
            case ERROR:
                return to == IndexState.BUILDING || to == IndexState.CLOSED;
            case CLOSED:
                return false;
            default:
                return false;
        }
    }

    /** Returns the current state (read under the read lock). */
    public IndexState getCurrentState() {
        stateLock.readLock().lock();
        try {
            return currentState.get();
        } finally {
            stateLock.readLock().unlock();
        }
    }

    /** Returns {@code true} when the current state is READY or SEARCHING. */
    public boolean canSearch() {
        IndexState state = getCurrentState();
        return state == IndexState.READY || state == IndexState.SEARCHING;
    }

    /** Returns {@code true} only when the current state is READY. */
    public boolean canUpdate() {
        return getCurrentState() == IndexState.READY;
    }

    /**
     * Moves the index to ERROR and records {@code message}.
     *
     * <p>If the index is already in ERROR, only the message is replaced with the newer one. If the
     * current state does not allow a transition to ERROR (for example READY), nothing changes, so
     * the stored message never describes a state the index is not actually in.
     *
     * @param message description of the failure
     */
    public void setError(String message) {
        stateLock.writeLock().lock();
        try {
            if (currentState.get() == IndexState.ERROR || applyTransitionLocked(IndexState.ERROR)) {
                errorMessage = message;
            }
        } finally {
            stateLock.writeLock().unlock();
        }
    }

    /**
     * Returns the message recorded by the most recent effective {@link #setError} call, or
     * {@code null} if none was recorded or it was cleared by a later transition.
     */
    public String getErrorMessage() {
        return errorMessage;
    }
}
/**
 * Batch index builder.
 *
 * <p>Inserts vectors into an {@code HnswIndex} batch by batch. Vectors inside one batch are
 * inserted in parallel on a fixed-size thread pool (one thread per available processor), and the
 * builder waits for a batch to finish before starting the next one.
 *
 * <p>The pool is owned by this builder: call {@link #shutdown()} once the builder is no longer
 * needed, otherwise its threads stay alive.
 */
public class BatchIndexBuilder {
    private static final Logger logger = LoggerFactory.getLogger(BatchIndexBuilder.class);
    private final HnswIndex index;
    private final int batchSize;
    private final ExecutorService executorService;
    private final ProgressListener progressListener;

    /**
     * Creates a builder without progress reporting.
     *
     * @param index     index to populate
     * @param batchSize number of vectors per batch; must be positive
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public BatchIndexBuilder(HnswIndex index, int batchSize) {
        this(index, batchSize, null);
    }

    /**
     * Creates a builder that reports progress after every batch.
     *
     * @param index            index to populate
     * @param batchSize        number of vectors per batch; must be positive
     * @param progressListener receives {@code onProgress(processed, total)} after each batch;
     *                         may be {@code null} to disable reporting
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public BatchIndexBuilder(HnswIndex index, int batchSize, ProgressListener progressListener) {
        // A zero batch size could make buildIndex loop forever; a negative one fails in subList.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        this.index = index;
        this.batchSize = batchSize;
        this.progressListener = progressListener;
        this.executorService = Executors.newFixedThreadPool(
                Runtime.getRuntime().availableProcessors());
    }

    /**
     * Builds the index from {@code vectors} in batches of {@code batchSize}.
     *
     * <p>Moves the index state to BUILDING, inserts all vectors, runs post-build optimization and
     * moves the state to READY. When more than 80% of the maximum heap is in use after a batch, a
     * GC is requested and the thread pauses for 100 ms.
     *
     * @param vectors vectors to insert
     * @throws IllegalStateException if the current index state does not allow a transition to
     *                               BUILDING (nothing is inserted in that case)
     * @throws RuntimeException      if any step fails; the index is moved to ERROR and the
     *                               original exception is attached as the cause
     */
    public void buildIndex(List<VectorWithId> vectors) {
        if (!index.getStateManager().transitionTo(IndexState.BUILDING)) {
            throw new IllegalStateException("Cannot start building index in state "
                    + index.getStateManager().getCurrentState());
        }
        try {
            int totalVectors = vectors.size();
            int processedVectors = 0;
            // Process the vectors batch by batch. Advancing to the batch end (instead of adding
            // batchSize) keeps the index arithmetic from overflowing near Integer.MAX_VALUE.
            int start = 0;
            while (start < totalVectors) {
                int end = start + Math.min(batchSize, totalVectors - start);
                List<VectorWithId> batch = vectors.subList(start, end);
                processBatch(batch);
                processedVectors += batch.size();
                start = end;
                // Report progress
                if (progressListener != null) {
                    progressListener.onProgress(processedVectors, totalVectors);
                }
                // Check memory usage
                if (shouldTriggerGC()) {
                    System.gc();
                    Thread.sleep(100); // give the GC some time to run
                }
            }
            // Optimization after the build has finished
            postBuildOptimization();
            index.getStateManager().transitionTo(IndexState.READY);
            logger.info("Batch index building completed. Total vectors: {}", totalVectors);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Thread.sleep was interrupted: restore the flag so callers can still observe it.
                Thread.currentThread().interrupt();
            }
            index.getStateManager().setError("Batch building failed: " + e.getMessage());
            throw new RuntimeException("Failed to build index", e);
        }
    }

    /**
     * Inserts one batch in parallel and waits for all insertions to complete.
     *
     * @throws RuntimeException if any insertion fails or the wait is interrupted
     */
    private void processBatch(List<VectorWithId> batch) {
        List<Future<Void>> futures = new ArrayList<>(batch.size());
        for (VectorWithId vectorWithId : batch) {
            Future<Void> future = executorService.submit(() -> {
                index.addVector(vectorWithId.getId(), vectorWithId.getVector());
                return null;
            });
            futures.add(future);
        }
        try {
            for (Future<Void> future : futures) {
                future.get();
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            // Don't leave the rest of this batch queued after a failure. cancel(false) only
            // prevents tasks that have not started; running insertions are left to finish.
            for (Future<Void> future : futures) {
                future.cancel(false);
            }
            throw new RuntimeException("Failed to process batch", e);
        }
    }

    /** Runs the post-build steps: connection optimization, memory compaction, cache warm-up. */
    private void postBuildOptimization() {
        logger.info("Starting post-build optimization...");
        // Connection optimization
        optimizeConnections();
        // Memory compaction
        compactMemory();
        // Cache warm-up
        warmupCache();
        logger.info("Post-build optimization completed");
    }

    /** Visits every node on every level from 0 to its own level and prunes/optimizes its links. */
    private void optimizeConnections() {
        for (HnswNode node : index.getAllNodes()) {
            for (int level = 0; level <= node.getLevel(); level++) {
                pruneAndOptimizeConnections(node, level);
            }
        }
    }

    /** Returns {@code true} when used heap exceeds 80% of the maximum heap size. */
    private boolean shouldTriggerGC() {
        Runtime runtime = Runtime.getRuntime();
        long maxMemory = runtime.maxMemory();
        long usedMemory = runtime.totalMemory() - runtime.freeMemory();
        return (double) usedMemory / maxMemory > 0.8;
    }

    /**
     * Stops accepting new work and lets the worker pool terminate once queued tasks finish.
     * The builder must not be used for further builds afterwards.
     */
    public void shutdown() {
        executorService.shutdown();
    }
}
/**
 * Memory-efficient index builder.
 *
 * <p>Splits a potentially very large stream of vectors into segments whose estimated memory
 * footprint stays under a configured per-segment limit, hands each segment to
 * {@code TemporaryStorage}, and finally merges all segments.
 */
public class MemoryEfficientBuilder {
    private static final long BYTES_PER_MB = 1024L * 1024L;

    /** Per-segment memory budget in bytes. */
    private final long memoryLimit;
    private final TemporaryStorage tempStorage;

    /**
     * @param memoryLimitMB per-segment memory budget in megabytes; must be positive
     * @throws IllegalArgumentException if {@code memoryLimitMB <= 0}
     */
    public MemoryEfficientBuilder(int memoryLimitMB) {
        if (memoryLimitMB <= 0) {
            throw new IllegalArgumentException("memoryLimitMB must be positive, got " + memoryLimitMB);
        }
        // Computed in long arithmetic: in int, limits of 2048 MB and above overflowed (2048 MB
        // became negative), which produced empty segments and an endless buildLargeIndex loop.
        this.memoryLimit = memoryLimitMB * BYTES_PER_MB;
        this.tempStorage = new TemporaryStorage();
    }

    /**
     * Builds an index from a large vector stream.
     *
     * <p>Phase 1 consumes the iterator into memory-bounded segments, storing each one through
     * {@code tempStorage}; phase 2 merges the segments. Temporary storage is cleaned up even if
     * either phase fails.
     *
     * <p>TODO(review): the merged segment is neither returned nor persisted by this method, and
     * all segments are also kept in the in-memory list — confirm the intended output.
     *
     * @param vectorIterator source of vectors; fully consumed by this call
     */
    public void buildLargeIndex(Iterator<VectorWithId> vectorIterator) {
        List<IndexSegment> segments = new ArrayList<>();
        try {
            // Phase 1: build segments
            while (vectorIterator.hasNext()) {
                IndexSegment segment = buildSegment(vectorIterator);
                segments.add(segment);
                // Store the segment on disk
                tempStorage.storeSegment(segment);
            }
            // Phase 2: merge segments
            mergeSegments(segments);
        } finally {
            // Remove temporary files
            tempStorage.cleanup();
        }
    }

    /**
     * Pulls vectors from the iterator into a new segment until the iterator is exhausted or the
     * estimated memory usage reaches {@code memoryLimit}. With a positive limit, at least one
     * vector is taken whenever the iterator has one, so the caller's loop always progresses.
     */
    private IndexSegment buildSegment(Iterator<VectorWithId> vectorIterator) {
        IndexSegment segment = new IndexSegment();
        long currentMemoryUsage = 0;
        while (vectorIterator.hasNext() && currentMemoryUsage < memoryLimit) {
            VectorWithId vector = vectorIterator.next();
            segment.addVector(vector);
            currentMemoryUsage += estimateVectorMemoryUsage(vector);
        }
        return segment;
    }

    /**
     * Merges the segments with a k-way merge over per-segment iterators.
     *
     * <p>No comparator is supplied to the priority queue, so {@code SegmentIterator} must be
     * {@code Comparable}; the merge order is whatever that ordering defines.
     */
    private IndexSegment mergeSegments(List<IndexSegment> segments) {
        PriorityQueue<SegmentIterator> pq = new PriorityQueue<>();
        for (IndexSegment segment : segments) {
            SegmentIterator iterator = new SegmentIterator(segment);
            // Only queue iterators that have data; polling an empty one would call next() on it.
            if (iterator.hasNext()) {
                pq.offer(iterator);
            }
        }
        IndexSegment mergedSegment = new IndexSegment();
        while (!pq.isEmpty()) {
            SegmentIterator iterator = pq.poll();
            VectorWithId vector = iterator.next();
            mergedSegment.addVector(vector);
            if (iterator.hasNext()) {
                pq.offer(iterator);
            }
        }
        return mergedSegment;
    }
}
/**
 * Incremental update manager.
 *
 * <p>Queues asynchronous add/update/delete requests in a bounded {@link UpdateBuffer} and
 * applies them to the index in batches. A batch is flushed every {@code flushIntervalMs} by a
 * background scheduler, immediately when the buffer fill level reaches
 * {@code bufferSize * updateThreshold}, when the buffer is full, and once more on
 * {@link #shutdown()}. Each request's future completes when its batch has been applied or has failed.
 */
public class IncrementalUpdateManager {
    private static final Logger logger = LoggerFactory.getLogger(IncrementalUpdateManager.class);
    private final HnswIndex index;
    private final UpdateBuffer updateBuffer;
    private final ScheduledExecutorService scheduler;
    /** Operations accepted into the buffer whose batch has not been resolved yet. */
    private final AtomicLong pendingUpdates = new AtomicLong(0);
    /** Set by {@link #shutdown()}; later submissions are rejected. */
    private volatile boolean closed = false;
    // Configuration
    private final int bufferSize;
    private final long flushIntervalMs;
    private final double updateThreshold;

    /**
     * Creates the manager and starts the periodic flush task.
     *
     * @param index  index the updates are applied to
     * @param config buffer size, flush interval (ms) and immediate-flush threshold
     */
    public IncrementalUpdateManager(HnswIndex index, IncrementalConfig config) {
        this.index = index;
        this.bufferSize = config.getBufferSize();
        this.flushIntervalMs = config.getFlushIntervalMs();
        this.updateThreshold = config.getUpdateThreshold();
        this.updateBuffer = new UpdateBuffer(bufferSize);
        this.scheduler = Executors.newSingleThreadScheduledExecutor();
        // Flush the buffer periodically
        scheduler.scheduleWithFixedDelay(
                this::flushBuffer,
                flushIntervalMs,
                flushIntervalMs,
                TimeUnit.MILLISECONDS
        );
    }

    /**
     * Adds a vector asynchronously.
     *
     * @return future completing with the id returned by {@code index.add}
     */
    public CompletableFuture<Long> addVectorAsync(Vector vector) {
        UpdateOperation operation = new UpdateOperation(
                UpdateOperation.Type.ADD,
                null,
                vector,
                System.currentTimeMillis()
        );
        return submitUpdate(operation);
    }

    /**
     * Replaces the vector with the given id asynchronously (remove, then add).
     *
     * @return future completing with {@code true} if the old vector existed and was replaced
     */
    public CompletableFuture<Boolean> updateVectorAsync(long id, Vector vector) {
        UpdateOperation operation = new UpdateOperation(
                UpdateOperation.Type.UPDATE,
                id,
                vector,
                System.currentTimeMillis()
        );
        return submitUpdate(operation).thenApply(result -> result != null);
    }

    /**
     * Deletes the vector with the given id asynchronously.
     *
     * @return future completing with {@code true} if {@code index.remove} reported a removal
     */
    public CompletableFuture<Boolean> deleteVectorAsync(long id) {
        UpdateOperation operation = new UpdateOperation(
                UpdateOperation.Type.DELETE,
                id,
                null,
                System.currentTimeMillis()
        );
        return submitUpdate(operation).thenApply(result -> result != null);
    }

    /** Returns the number of accepted operations whose batch has not been resolved yet. */
    public long getPendingUpdateCount() {
        return pendingUpdates.get();
    }

    /**
     * Enqueues an operation and returns its future.
     *
     * <p>The future completes exceptionally with {@link IllegalStateException} if the manager has
     * been shut down, or if the buffer is still full after one synchronous flush.
     */
    private CompletableFuture<Long> submitUpdate(UpdateOperation operation) {
        CompletableFuture<Long> future = new CompletableFuture<>();
        operation.setFuture(future);
        if (closed) {
            future.completeExceptionally(
                    new IllegalStateException("IncrementalUpdateManager has been shut down"));
            return future;
        }
        // Count the operation before a flush can see it, so a concurrent flush can never
        // decrement the counter below zero.
        pendingUpdates.incrementAndGet();
        boolean added = updateBuffer.offer(operation);
        if (!added) {
            // Buffer full: flush synchronously to make room, then retry once.
            flushBuffer();
            added = updateBuffer.offer(operation);
        }
        if (!added) {
            // Still full (other producers refilled it): fail instead of silently dropping the
            // operation with a future that would never complete.
            pendingUpdates.decrementAndGet();
            future.completeExceptionally(
                    new IllegalStateException("Update buffer is full; operation rejected"));
            return future;
        }
        // Flush right away when the fill threshold is reached, or when shutdown() ran concurrently
        // and its final flush may have happened before this operation was enqueued.
        if (closed || shouldFlushImmediately()) {
            flushBuffer();
        }
        return future;
    }

    /**
     * Drains the buffer and applies the drained operations as one batch.
     *
     * <p>If the index state cannot move to UPDATING, the batch is rejected: every future fails with
     * {@link IllegalStateException} and the index is not touched. If applying the batch throws,
     * every future fails with that exception and the index is moved to ERROR (operations applied
     * before the failure are not rolled back).
     */
    private synchronized void flushBuffer() {
        List<UpdateOperation> operations = updateBuffer.drainAll();
        if (operations.isEmpty()) {
            return;
        }
        logger.debug("Flushing {} update operations", operations.size());
        try {
            if (!index.getStateManager().transitionTo(IndexState.UPDATING)) {
                // Mutating the index from a state that does not permit UPDATING would bypass the
                // state machine (and the later transition to READY could end a different phase).
                IllegalStateException rejected = new IllegalStateException(
                        "Index cannot accept updates in state "
                                + index.getStateManager().getCurrentState());
                logger.warn("Rejecting {} update operations: {}", operations.size(), rejected.getMessage());
                for (UpdateOperation op : operations) {
                    op.getFuture().completeExceptionally(rejected);
                }
                return;
            }
            // Apply the batch
            executeBatchUpdates(operations);
            index.getStateManager().transitionTo(IndexState.READY);
            // Complete all futures
            for (UpdateOperation op : operations) {
                op.getFuture().complete(op.getResultId());
            }
        } catch (Exception e) {
            logger.error("Failed to flush update buffer", e);
            // Mark every operation of the batch as failed
            for (UpdateOperation op : operations) {
                op.getFuture().completeExceptionally(e);
            }
            index.getStateManager().setError("Update flush failed: " + e.getMessage());
        } finally {
            // Every drained operation is resolved at this point, whatever the outcome.
            pendingUpdates.addAndGet(-operations.size());
        }
    }

    /**
     * Applies a batch grouped by type: deletes, then updates, then adds.
     *
     * <p>Note: the regrouping does not preserve submission order across types; for example a
     * DELETE submitted after an UPDATE of the same id is applied first.
     */
    private void executeBatchUpdates(List<UpdateOperation> operations) {
        Map<UpdateOperation.Type, List<UpdateOperation>> groupedOps = operations.stream()
                .collect(Collectors.groupingBy(UpdateOperation::getType));
        executeDeletes(groupedOps.getOrDefault(UpdateOperation.Type.DELETE, Collections.emptyList()));
        executeUpdates(groupedOps.getOrDefault(UpdateOperation.Type.UPDATE, Collections.emptyList()));
        executeAdds(groupedOps.getOrDefault(UpdateOperation.Type.ADD, Collections.emptyList()));
    }

    /** Removes each id; the result id is the removed id, or {@code null} if nothing was removed. */
    private void executeDeletes(List<UpdateOperation> deleteOps) {
        for (UpdateOperation op : deleteOps) {
            boolean result = index.remove(op.getId());
            op.setResultId(result ? op.getId() : null);
        }
    }

    /**
     * Replaces each vector by removing the old id and adding the new vector (simple strategy);
     * the result id is the newly assigned id, or {@code null} if the old id was not removed.
     */
    private void executeUpdates(List<UpdateOperation> updateOps) {
        for (UpdateOperation op : updateOps) {
            boolean removed = index.remove(op.getId());
            if (removed) {
                long newId = index.add(op.getVector());
                op.setResultId(newId);
            } else {
                op.setResultId(null);
            }
        }
    }

    /** Adds each vector; the result id is the id returned by {@code index.add}. */
    private void executeAdds(List<UpdateOperation> addOps) {
        for (UpdateOperation op : addOps) {
            long id = index.add(op.getVector());
            op.setResultId(id);
        }
    }

    /** Returns {@code true} once the buffer holds at least {@code bufferSize * updateThreshold} operations. */
    private boolean shouldFlushImmediately() {
        return updateBuffer.size() >= bufferSize * updateThreshold;
    }

    /**
     * Shuts the manager down: rejects further submissions, stops the periodic flush and flushes
     * whatever is still buffered.
     */
    public void shutdown() {
        closed = true;
        scheduler.shutdown();
        flushBuffer(); // final flush
    }
}
/**
 * Bounded, thread-safe FIFO buffer of pending {@link UpdateOperation}s.
 *
 * <p>Backed by a {@link LinkedBlockingQueue} with a fixed capacity. Offering never blocks: it
 * simply reports {@code false} when the buffer is full.
 */
public class UpdateBuffer {
    private final int capacity;
    private final LinkedBlockingQueue<UpdateOperation> queue;

    /**
     * @param capacity maximum number of buffered operations
     */
    public UpdateBuffer(int capacity) {
        this.queue = new LinkedBlockingQueue<>(capacity);
        this.capacity = capacity;
    }

    /**
     * Appends an operation if there is room.
     *
     * @return {@code true} if the operation was buffered, {@code false} if the buffer was full
     */
    public boolean offer(UpdateOperation operation) {
        final boolean accepted = queue.offer(operation);
        return accepted;
    }

    /**
     * Removes every currently buffered operation and returns them in FIFO order.
     *
     * @return a new mutable list; empty when nothing was buffered
     */
    public List<UpdateOperation> drainAll() {
        final List<UpdateOperation> drained = new ArrayList<>();
        queue.drainTo(drained);
        return drained;
    }

    /** Returns the number of buffered operations. */
    public int size() {
        return queue.size();
    }

    /** Returns {@code true} when no operation is buffered. */
    public boolean isEmpty() {
        return size() == 0;
    }
}
/**
 * A single queued index mutation (add, update or delete) plus the future that reports its result.
 *
 * <p>The type, id, vector and timestamp are fixed at construction. The completion future and the
 * result id are mutable and attached later.
 */
public class UpdateOperation {
    /** Kind of mutation. */
    public enum Type {
        ADD,
        UPDATE,
        DELETE
    }

    // Immutable request data
    private final Type type;
    private final Long id;
    private final Vector vector;
    private final long timestamp;

    // Mutable bookkeeping
    private CompletableFuture<Long> future;
    private Long resultId;

    /**
     * @param type      kind of mutation
     * @param id        target id; may be {@code null} (e.g. for ADD)
     * @param vector    payload vector; may be {@code null} (e.g. for DELETE)
     * @param timestamp creation time supplied by the caller
     */
    public UpdateOperation(Type type, Long id, Vector vector, long timestamp) {
        this.type = type;
        this.id = id;
        this.vector = vector;
        this.timestamp = timestamp;
    }

    public Type getType() {
        return type;
    }

    public Long getId() {
        return id;
    }

    public Vector getVector() {
        return vector;
    }

    public long getTimestamp() {
        return timestamp;
    }

    /** Returns the future attached via {@link #setFuture}, or {@code null} if none yet. */
    public CompletableFuture<Long> getFuture() {
        return future;
    }

    public void setFuture(CompletableFuture<Long> completion) {
        future = completion;
    }

    /** Returns the result id recorded via {@link #setResultId}, or {@code null}. */
    public Long getResultId() {
        return resultId;
    }

    public void setResultId(Long outcomeId) {
        resultId = outcomeId;
    }
}
/**
 * Adaptive search-parameter tuner.
 *
 * <p>Groups queries by a signature (dimension, k, index size rounded down to a multiple of
 * 10,000), collects feedback samples per signature, and picks {@code ef} either heuristically
 * (fewer than 10 samples) or through {@link ParameterOptimizer} (10 or more samples).
 */
public class AdaptiveSearchTuner {
    private static final Logger logger = LoggerFactory.getLogger(AdaptiveSearchTuner.class);

    // Default parameter range
    private static final int MIN_EF = 10;
    private static final int MAX_EF = 1000;
    private static final double TARGET_RECALL = 0.95;
    private static final double MAX_LATENCY_MS = 50.0;

    /** Samples required before history-based optimization replaces the heuristic. */
    private static final int MIN_SAMPLES_FOR_OPTIMIZATION = 10;
    /** Once a signature holds more samples than this, the oldest ones are evicted. */
    private static final int MAX_SAMPLES_PER_SIGNATURE = 1000;
    private static final int SAMPLES_TO_EVICT = 100;
    /** Index sizes are rounded down to a multiple of this when building signatures. */
    private static final int INDEX_SIZE_BUCKET = 10000;

    private final Map<String, ParameterHistory> historyBySignature = new ConcurrentHashMap<>();
    private final ParameterOptimizer optimizer;

    public AdaptiveSearchTuner() {
        this.optimizer = new ParameterOptimizer();
    }

    /**
     * Chooses search parameters for a query.
     *
     * @param context query characteristics
     * @return optimizer output when enough history exists for the query's signature, otherwise
     *     heuristic parameters
     */
    public SearchParameters tuneParameters(QueryContext context) {
        ParameterHistory history = historyBySignature.get(generateQuerySignature(context));
        boolean enoughHistory = history != null
                && history.getSampleCount() >= MIN_SAMPLES_FOR_OPTIMIZATION;
        return enoughHistory
                ? optimizer.optimize(history, context)
                : generateHeuristicParameters(context);
    }

    /**
     * Records the outcome of a search so later tuning can use it.
     *
     * @param context query characteristics
     * @param params  parameters that were used
     * @param metrics measured results
     */
    public void recordFeedback(QueryContext context, SearchParameters params, SearchMetrics metrics) {
        ParameterHistory history = historyBySignature.computeIfAbsent(
                generateQuerySignature(context), signature -> new ParameterHistory());
        history.addSample(new ParameterSample(params, metrics, System.currentTimeMillis()));
        // Bound the history size
        if (history.getSampleCount() > MAX_SAMPLES_PER_SIGNATURE) {
            history.removeOldestSamples(SAMPLES_TO_EVICT);
        }
    }

    /** Builds the grouping key from dimension, k and the bucketed index size. */
    private String generateQuerySignature(QueryContext context) {
        int bucketedSize = context.getIndexSize() / INDEX_SIZE_BUCKET * INDEX_SIZE_BUCKET;
        return String.format("dim_%d_k_%d_size_%d",
                context.getQueryDimension(),
                context.getK(),
                bucketedSize);
    }

    /**
     * Heuristic ef: {@code k * log2(indexSize)} capped at MAX_EF, but never below {@code 2 * k},
     * and finally at least MIN_EF.
     */
    private SearchParameters generateHeuristicParameters(QueryContext context) {
        int k = context.getK();
        int logScaled = (int) (k * Math.log(context.getIndexSize()) / Math.log(2));
        int ef = Math.max(k * 2, Math.min(MAX_EF, logScaled));
        return new SearchParameters(Math.max(MIN_EF, ef));
    }
}
/**
 * Parameter optimizer.
 *
 * <p>Chooses {@code ef} from recorded search feedback with an exhaustive grid search (not
 * Bayesian optimization): every candidate in [MIN_EF, MAX_EF] with step EF_STEP is scored by the
 * average score of the samples whose {@code ef} lies within EF_WINDOW of it, and the first
 * best-scoring candidate wins.
 */
public class ParameterOptimizer {
    // Search range and latency budget. These have the same values as AdaptiveSearchTuner's
    // constants, which are private to that class and therefore not visible here.
    private static final int MIN_EF = 10;
    private static final int MAX_EF = 1000;
    private static final double MAX_LATENCY_MS = 50.0;

    private static final int EF_STEP = 10;
    private static final int EF_WINDOW = 20;
    private static final int RECENT_SAMPLE_COUNT = 100;
    private static final double RECALL_WEIGHT = 0.7;
    private static final double LATENCY_WEIGHT = 0.3;

    /**
     * Computes search parameters from the most recent samples in {@code history}.
     *
     * @param history recorded feedback for one query signature
     * @param context the query being tuned; its {@code k} is a lower bound for the returned ef
     * @return parameters with the best grid candidate for ef, raised to at least {@code k}
     */
    public SearchParameters optimize(ParameterHistory history, QueryContext context) {
        List<ParameterSample> samples = history.getRecentSamples(RECENT_SAMPLE_COUNT);
        OptimizationResult result = gridSearch(samples);
        // HNSW needs ef >= k to return k results; the grid may otherwise favour a smaller, faster ef.
        int ef = Math.max(result.getOptimalEf(), context.getK());
        return new SearchParameters(ef);
    }

    /**
     * Scores every grid candidate and returns the best one. When no sample is near any
     * candidate, all scores are 0 and MIN_EF is returned.
     */
    private OptimizationResult gridSearch(List<ParameterSample> samples) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestEf = MIN_EF;
        for (int ef = MIN_EF; ef <= MAX_EF; ef += EF_STEP) {
            double score = evaluateParameters(ef, samples);
            if (score > bestScore) {
                bestScore = score;
                bestEf = ef;
            }
        }
        return new OptimizationResult(bestEf, bestScore);
    }

    /**
     * Average score of the samples whose ef is within EF_WINDOW of {@code ef}; 0.0 if none.
     *
     * <p>Per-sample score = 0.7 * recall + 0.3 * latencyScore, where latencyScore falls linearly
     * from 1 at 0 ms to 0 at MAX_LATENCY_MS and is clamped at 0 beyond that.
     */
    private double evaluateParameters(int ef, List<ParameterSample> samples) {
        double totalScore = 0;
        int count = 0;
        for (ParameterSample sample : samples) {
            if (Math.abs(sample.getParameters().getEf() - ef) <= EF_WINDOW) {
                SearchMetrics metrics = sample.getMetrics();
                double recallScore = metrics.getRecall();
                double latencyScore = Math.max(0, 1.0 - metrics.getLatencyMs() / MAX_LATENCY_MS);
                totalScore += RECALL_WEIGHT * recallScore + LATENCY_WEIGHT * latencyScore;
                count++;
            }
        }
        return count > 0 ? totalScore / count : 0.0;
    }
}
/**
 * Immutable description of a search request.
 *
 * <p>Carries the query dimension, the requested number of neighbours {@code k}, the index size
 * and the distance metric name, plus the creation time taken from
 * {@link System#currentTimeMillis()} when the object is constructed.
 */
public class QueryContext {
    private final int queryDimension;
    private final int k;
    private final int indexSize;
    private final String distanceMetric;
    private final long timestamp;

    /**
     * @param queryDimension dimension of the query vector
     * @param k              number of neighbours requested
     * @param indexSize      number of vectors in the index
     * @param distanceMetric name of the distance metric
     */
    public QueryContext(int queryDimension, int k, int indexSize, String distanceMetric) {
        this.queryDimension = queryDimension;
        this.k = k;
        this.indexSize = indexSize;
        this.distanceMetric = distanceMetric;
        this.timestamp = System.currentTimeMillis();
    }

    public int getQueryDimension() {
        return queryDimension;
    }

    public int getK() {
        return k;
    }

    public int getIndexSize() {
        return indexSize;
    }

    public String getDistanceMetric() {
        return distanceMetric;
    }

    /** Returns the construction time in epoch milliseconds. */
    public long getTimestamp() {
        return timestamp;
    }
}
/**
 * Quality and cost measurements collected for one search.
 *
 * <p>All values are fixed at construction. Recall and precision are presumably fractions in
 * [0, 1] — confirm with whatever produces these metrics.
 */
public class SearchMetrics {
    // Quality
    private final double recall;
    private final double precision;
    // Cost
    private final double latencyMs;
    private final int visitedNodes;
    private final int distanceCalculations;

    /**
     * @param recall               measured recall
     * @param precision            measured precision
     * @param latencyMs            search latency in milliseconds
     * @param visitedNodes         number of graph nodes visited
     * @param distanceCalculations number of distance computations performed
     */
    public SearchMetrics(double recall, double precision, double latencyMs,
                         int visitedNodes, int distanceCalculations) {
        this.recall = recall;
        this.precision = precision;
        this.latencyMs = latencyMs;
        this.visitedNodes = visitedNodes;
        this.distanceCalculations = distanceCalculations;
    }

    public double getRecall() {
        return recall;
    }

    public double getPrecision() {
        return precision;
    }

    public double getLatencyMs() {
        return latencyMs;
    }

    public int getVisitedNodes() {
        return visitedNodes;
    }

    public int getDistanceCalculations() {
        return distanceCalculations;
    }
}
/**
 * Immutable set of search-time knobs.
 *
 * <p>The single-argument constructor uses {@code useApproximation = true} and
 * {@code maxVisitedNodes = -1}.
 */
public class SearchParameters {
    private static final boolean DEFAULT_USE_APPROXIMATION = true;
    // Presumably -1 means "no limit on visited nodes" — confirm with the search implementation.
    private static final int DEFAULT_MAX_VISITED_NODES = -1;

    private final int ef;
    private final boolean useApproximation;
    private final int maxVisitedNodes;

    /**
     * Creates parameters with the given {@code ef} and default values for the other knobs.
     *
     * @param ef the HNSW {@code ef} search parameter
     */
    public SearchParameters(int ef) {
        this(ef, DEFAULT_USE_APPROXIMATION, DEFAULT_MAX_VISITED_NODES);
    }

    /**
     * @param ef               the HNSW {@code ef} search parameter
     * @param useApproximation approximation flag passed through to the search
     * @param maxVisitedNodes  visited-node limit passed through to the search
     */
    public SearchParameters(int ef, boolean useApproximation, int maxVisitedNodes) {
        this.ef = ef;
        this.useApproximation = useApproximation;
        this.maxVisitedNodes = maxVisitedNodes;
    }

    public int getEf() {
        return ef;
    }

    public boolean isUseApproximation() {
        return useApproximation;
    }

    public int getMaxVisitedNodes() {
        return maxVisitedNodes;
    }
}
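本章介绍了索引构建与搜索算法的高级特性:
- 索引状态管理:用状态机约束索引在创建、构建、就绪、搜索、更新、错误、关闭等状态之间的合法转换;
- 批量构建:分批并行插入向量,并在构建完成后进行连接优化、内存压缩和缓存预热;
- 内存高效构建:按内存上限将大规模数据切分为索引段分别构建,再合并为最终索引;
- 增量更新:通过缓冲区批量执行异步的添加、更新和删除操作;
- 搜索参数自适应调节:根据查询特征和历史反馈自动选择 ef 参数。
这些高级特性使HNSW索引能够在生产环境中稳定、高效地运行,既支持大规模数据的动态更新,也能持续优化搜索性能。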
思考题: