对象池是减少频繁对象创建和垃圾回收开销的重要技术,特别适用于高频使用的对象。
package com.jvector.optimization.memory;
import com.sun.management.HotSpotDiagnosticMXBean;

import java.lang.management.ManagementFactory;
import java.lang.management.MemoryMXBean;
import java.lang.management.MemoryUsage;
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import javax.management.MBeanServer;
/**
 * Generic object pool backed by a lock-free queue.
 *
 * <p>Reduces allocation and GC pressure for frequently reused objects. The
 * pool retains at most {@code maxSize} idle objects; releases beyond that
 * bound drop the object and leave it for the garbage collector.
 */
public class ObjectPool<T> {
    private final ConcurrentLinkedQueue<T> pool;
    private final Supplier<T> objectFactory;
    private final int maxSize;
    // Count of objects currently idle in the pool; bounded by maxSize.
    private final AtomicInteger currentSize;

    /**
     * @param objectFactory factory used to create new instances on demand
     * @param maxSize       maximum number of idle objects retained
     */
    public ObjectPool(Supplier<T> objectFactory, int maxSize) {
        this.objectFactory = objectFactory;
        this.maxSize = maxSize;
        this.pool = new ConcurrentLinkedQueue<>();
        this.currentSize = new AtomicInteger(0);
        // Pre-populate so early callers do not pay allocation cost.
        warmupPool();
    }

    /**
     * Acquires an object from the pool, creating a new one when empty.
     */
    public T acquire() {
        T object = pool.poll();
        if (object == null) {
            object = objectFactory.get();
        } else {
            currentSize.decrementAndGet();
        }
        return object;
    }

    /**
     * Returns an object to the pool. The object is reset (when it implements
     * {@link Resetable}) before becoming available again; it is silently
     * dropped when the pool is already full.
     */
    public void release(T object) {
        if (object == null) {
            return;
        }
        // Reserve a slot with CAS before offering. The original check-then-act
        // sequence (get() < maxSize, then offer, then increment) allowed
        // concurrent releasers to grow the pool past maxSize.
        while (true) {
            int size = currentSize.get();
            if (size >= maxSize) {
                return; // pool full — drop the object
            }
            if (currentSize.compareAndSet(size, size + 1)) {
                resetObject(object);
                pool.offer(object); // unbounded CLQ: offer always succeeds
                return;
            }
        }
    }

    /**
     * Pre-populates the pool with a quarter of its capacity, capped at 10.
     */
    private void warmupPool() {
        int initialSize = Math.min(maxSize / 4, 10);
        for (int i = 0; i < initialSize; i++) {
            pool.offer(objectFactory.get());
            currentSize.incrementAndGet();
        }
    }

    /**
     * Resets an object's state before reuse, when it supports resetting.
     */
    private void resetObject(T object) {
        if (object instanceof Resetable) {
            ((Resetable) object).reset();
        }
    }

    /**
     * Returns a snapshot of the pool's current state.
     */
    public PoolStats getStats() {
        return new PoolStats(currentSize.get(), maxSize, pool.size());
    }
}
/**
 * Implemented by pooled objects that can clear their own state before being
 * handed out again by {@code ObjectPool}.
 */
public interface Resetable {
// Clears any mutable state so the object is safe to reuse.
void reset();
}
/**
 * Object pool for fixed-dimension {@code float[]} vectors.
 */
public class VectorPool extends ObjectPool<float[]> {
    private final int vectorDimension;

    /**
     * @param vectorDimension dimension every pooled vector must have
     * @param maxSize         maximum number of idle vectors to retain
     */
    public VectorPool(int vectorDimension, int maxSize) {
        super(() -> new float[vectorDimension], maxSize);
        this.vectorDimension = vectorDimension;
    }

    /**
     * Returns a vector to the pool. Vectors of any other dimension are
     * dropped so the pool only ever hands out arrays of
     * {@code vectorDimension} elements.
     */
    @Override
    public void release(float[] vector) {
        if (vector != null && vector.length == vectorDimension) {
            // Zero the data so stale values cannot leak to the next user.
            // Fully qualified: the original called Arrays.fill without the
            // file importing java.util.Arrays.
            java.util.Arrays.fill(vector, 0.0f);
            super.release(vector);
        }
    }
}
/**
 * Pool of reusable result lists for search operations.
 */
public class SearchResultPool extends ObjectPool<List<SearchResult>> {

    public SearchResultPool(int maxSize) {
        super(ArrayList::new, maxSize);
    }

    /**
     * Clears a result list and hands it back to the pool; null is ignored.
     */
    @Override
    public void release(List<SearchResult> results) {
        if (results == null) {
            return;
        }
        results.clear();
        super.release(results);
    }
}
/**
 * Multi-level cache manager for the search engine:
 * L1 = query results, L2 = vector data, L3 = distance computations.
 * All levels are Caffeine caches with statistics recording enabled.
 */
public class MultiLevelCacheManager {
    private static final Logger logger = LoggerFactory.getLogger(MultiLevelCacheManager.class);
    private final Map<String, Cache<?, ?>> caches = new ConcurrentHashMap<>();
    private final CacheConfig config;
    private final CacheMetrics metrics;

    public MultiLevelCacheManager(CacheConfig config) {
        this.config = config;
        this.metrics = new CacheMetrics();
        initializeCaches();
    }

    /**
     * Builds the three cache levels and registers them by name.
     */
    private void initializeCaches() {
        // L1: query-result cache.
        Cache<String, List<SearchResult>> queryCache = Caffeine.newBuilder()
            .maximumSize(config.getQueryCacheSize())
            .expireAfterWrite(config.getQueryCacheTtl())
            .removalListener(this::onQueryCacheRemoval)
            .recordStats()
            .build();
        caches.put("query", queryCache);
        // L2: vector-data cache. Caffeine forbids combining a weigher with
        // maximumSize (it throws IllegalStateException at build time); with a
        // weigher the bound must be maximumWeight. The original paired
        // maximumSize with a weigher, so building this cache always failed.
        // getVectorCacheSize() is therefore interpreted as a total byte
        // weight here.
        Cache<Long, Vector> vectorCache = Caffeine.newBuilder()
            .maximumWeight(config.getVectorCacheSize())
            .expireAfterAccess(config.getVectorCacheTtl())
            .weigher(this::calculateVectorWeight)
            .recordStats()
            .build();
        caches.put("vector", vectorCache);
        // L3: distance-computation cache.
        Cache<String, Float> distanceCache = Caffeine.newBuilder()
            .maximumSize(config.getDistanceCacheSize())
            .expireAfterWrite(Duration.ofMinutes(30))
            .recordStats()
            .build();
        caches.put("distance", distanceCache);
    }

    /** Returns the L1 query-result cache. */
    @SuppressWarnings("unchecked")
    public Cache<String, List<SearchResult>> getQueryCache() {
        return (Cache<String, List<SearchResult>>) caches.get("query");
    }

    /** Returns the L2 vector-data cache. */
    @SuppressWarnings("unchecked")
    public Cache<Long, Vector> getVectorCache() {
        return (Cache<Long, Vector>) caches.get("vector");
    }

    /**
     * Weight of one cached vector: 8 bytes for the Long id plus 4 bytes per
     * float component.
     */
    private int calculateVectorWeight(Long id, Vector vector) {
        return 8 + vector.getDimension() * 4;
    }

    /**
     * Query-cache removal listener; feeds the metrics collector.
     */
    private void onQueryCacheRemoval(String key, List<SearchResult> value, RemovalCause cause) {
        metrics.recordCacheRemoval("query", cause);
        logger.debug("Query cache removal: key={}, cause={}", key, cause);
    }

    /**
     * Warms up the vector and distance caches from the given index.
     */
    public void warmupCaches(HnswIndex index) {
        logger.info("Warming up caches...");
        // Load hot vector data.
        warmupVectorCache(index);
        // Pre-compute commonly needed distances.
        warmupDistanceCache(index);
        logger.info("Cache warmup completed");
    }

    /**
     * Loads the most recently accessed vectors into the L2 cache.
     */
    private void warmupVectorCache(HnswIndex index) {
        Cache<Long, Vector> vectorCache = getVectorCache();
        Collection<Long> recentIds = index.getRecentlyAccessedIds(1000);
        for (Long id : recentIds) {
            Vector vector = index.getVector(id);
            if (vector != null) {
                vectorCache.put(id, vector);
            }
        }
    }

    /**
     * Pre-computes frequently needed distances into the L3 cache.
     * TODO(review): warmupCaches called this method but it was never defined
     * in the original file; implement once the hot-pair selection strategy is
     * decided.
     */
    private void warmupDistanceCache(HnswIndex index) {
        // intentional no-op placeholder
    }

    /**
     * Returns a snapshot of per-cache statistics keyed by cache name.
     */
    public Map<String, CacheStats> getCacheStats() {
        Map<String, CacheStats> stats = new HashMap<>();
        for (Map.Entry<String, Cache<?, ?>> entry : caches.entrySet()) {
            stats.put(entry.getKey(), entry.getValue().stats());
        }
        return stats;
    }

    /**
     * Invalidates every entry in every cache level.
     */
    public void clearAllCaches() {
        caches.values().forEach(Cache::invalidateAll);
        logger.info("All caches cleared");
    }
}
/**
 * Tunable sizes and TTLs for the cache levels.
 *
 * <p>Not thread-safe for concurrent reconfiguration; configure once at
 * startup.
 */
public class CacheConfig {
    private long queryCacheSize = 10000;
    private Duration queryCacheTtl = Duration.ofMinutes(10);
    private long vectorCacheSize = 100000;
    private Duration vectorCacheTtl = Duration.ofHours(1);
    private long distanceCacheSize = 1000000;

    // Getters
    public long getQueryCacheSize() { return queryCacheSize; }
    public Duration getQueryCacheTtl() { return queryCacheTtl; }
    public long getVectorCacheSize() { return vectorCacheSize; }
    public Duration getVectorCacheTtl() { return vectorCacheTtl; }
    public long getDistanceCacheSize() { return distanceCacheSize; }

    // Setters — the original comment promised "Getters and setters" but
    // declared no setters, leaving the defaults unconfigurable.
    public void setQueryCacheSize(long queryCacheSize) { this.queryCacheSize = queryCacheSize; }
    public void setQueryCacheTtl(Duration queryCacheTtl) { this.queryCacheTtl = queryCacheTtl; }
    public void setVectorCacheSize(long vectorCacheSize) { this.vectorCacheSize = vectorCacheSize; }
    public void setVectorCacheTtl(Duration vectorCacheTtl) { this.vectorCacheTtl = vectorCacheTtl; }
    public void setDistanceCacheSize(long distanceCacheSize) { this.distanceCacheSize = distanceCacheSize; }
}
/**
 * Background memory monitor: tracks heap usage, drains weak-reference queues
 * of collected tracked objects, and periodically suggests a GC.
 *
 * <p>Call {@link #shutdown()} when the detector is no longer needed;
 * otherwise the scheduler thread keeps running.
 */
public class MemoryLeakDetector {
    private static final Logger logger = LoggerFactory.getLogger(MemoryLeakDetector.class);
    private final MemoryMXBean memoryBean;
    private final ScheduledExecutorService scheduler;
    private final Map<String, ReferenceQueue<Object>> referenceQueues;
    private final Set<WeakReference<Object>> trackedReferences;

    public MemoryLeakDetector() {
        this.memoryBean = ManagementFactory.getMemoryMXBean();
        this.scheduler = Executors.newSingleThreadScheduledExecutor();
        this.referenceQueues = new ConcurrentHashMap<>();
        this.trackedReferences = ConcurrentHashMap.newKeySet();
        startMonitoring();
    }

    /**
     * Schedules the periodic monitoring tasks.
     */
    private void startMonitoring() {
        // Heap-usage check every 30 seconds.
        scheduler.scheduleWithFixedDelay(this::checkMemoryUsage, 30, 30, TimeUnit.SECONDS);
        // Reference-queue drain every 10 seconds.
        scheduler.scheduleWithFixedDelay(this::checkReferenceQueues, 10, 10, TimeUnit.SECONDS);
        // GC suggestion every 5 minutes.
        scheduler.scheduleWithFixedDelay(this::suggestGC, 300, 300, TimeUnit.SECONDS);
    }

    /**
     * Stops all background monitoring tasks. The original class never shut
     * its scheduler down, leaking the thread.
     */
    public void shutdown() {
        scheduler.shutdownNow();
    }

    /**
     * Checks current heap usage and triggers leak analysis above 80%.
     */
    private void checkMemoryUsage() {
        MemoryUsage heapUsage = memoryBean.getHeapMemoryUsage();
        long used = heapUsage.getUsed();
        long max = heapUsage.getMax();
        if (max <= 0) {
            // MemoryUsage.getMax() returns -1 when the maximum is undefined;
            // the original divided by it unconditionally.
            return;
        }
        double usagePercentage = (double) used / max * 100;
        if (usagePercentage > 80) {
            // SLF4J substitutes only plain "{}" placeholders; the original
            // "{:.2f}" was emitted literally, so format the number explicitly.
            logger.warn("High memory usage detected: {}% ({} / {} bytes)",
                String.format("%.2f", usagePercentage), used, max);
            // Look for potential leaks while pressure is high.
            checkForMemoryLeaks();
        }
        // Record memory-usage metrics.
        recordMemoryMetrics(used, max, usagePercentage);
    }

    /**
     * Runs the leak-analysis passes: optional heap-dump analysis, reference
     * counting, and long-lived-object inspection.
     */
    private void checkForMemoryLeaks() {
        if (shouldAnalyzeHeapDump()) {
            analyzeHeapDump();
        }
        checkObjectReferenceCounts();
        checkLongLivedObjects();
    }

    /**
     * Tracks an object with a weak reference so its collection can be
     * observed per category.
     */
    public void trackObject(Object object, String category) {
        ReferenceQueue<Object> queue = referenceQueues.computeIfAbsent(
            category, k -> new ReferenceQueue<>());
        WeakReference<Object> ref = new WeakReference<>(object, queue);
        trackedReferences.add(ref);
    }

    /**
     * Drains each category's reference queue, dropping references whose
     * referents have been garbage collected.
     */
    private void checkReferenceQueues() {
        for (Map.Entry<String, ReferenceQueue<Object>> entry : referenceQueues.entrySet()) {
            String category = entry.getKey();
            ReferenceQueue<Object> queue = entry.getValue();
            Reference<?> ref;
            int collectedCount = 0;
            while ((ref = queue.poll()) != null) {
                trackedReferences.remove(ref);
                collectedCount++;
            }
            if (collectedCount > 0) {
                logger.debug("Collected {} references from category: {}", collectedCount, category);
            }
        }
    }

    /**
     * Suggests a GC and logs how much memory it freed.
     *
     * <p>NOTE(review): System.gc() is only a hint and explicit GC requests
     * are usually discouraged in production; kept to preserve behavior.
     */
    private void suggestGC() {
        MemoryUsage beforeGC = memoryBean.getHeapMemoryUsage();
        System.gc();
        // Give the collector a moment before sampling again; runs on the
        // single scheduler thread, so the sleep only delays other tasks.
        try {
            Thread.sleep(1000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        MemoryUsage afterGC = memoryBean.getHeapMemoryUsage();
        long freedMemory = beforeGC.getUsed() - afterGC.getUsed();
        if (freedMemory > 0) {
            logger.info("GC freed {} bytes of memory", freedMemory);
        }
    }

    /**
     * Generates and analyzes a heap dump, logging any failure.
     */
    private void analyzeHeapDump() {
        try {
            String dumpFile = generateHeapDump();
            analyzeHeapDumpFile(dumpFile);
        } catch (Exception e) {
            logger.error("Failed to analyze heap dump", e);
        }
    }

    /**
     * Writes a heap dump of live objects via the HotSpot diagnostic MXBean.
     *
     * @return the generated dump file name
     */
    private String generateHeapDump() throws Exception {
        MBeanServer server = ManagementFactory.getPlatformMBeanServer();
        HotSpotDiagnosticMXBean diagBean = ManagementFactory.newPlatformMXBeanProxy(
            server, "com.sun.management:type=HotSpotDiagnostic", HotSpotDiagnosticMXBean.class);
        String fileName = "heapdump_" + System.currentTimeMillis() + ".hprof";
        diagBean.dumpHeap(fileName, true);
        return fileName;
    }

    /**
     * Publishes heap-usage gauges to the metrics system.
     */
    private void recordMemoryMetrics(long used, long max, double percentage) {
        MetricsCollector.getInstance().recordGauge("memory.heap.used", used);
        MetricsCollector.getInstance().recordGauge("memory.heap.max", max);
        MetricsCollector.getInstance().recordGauge("memory.heap.usage_percentage", percentage);
    }

    /**
     * Builds a report of the detector's current view of potential leaks.
     */
    public MemoryLeakReport getLeakReport() {
        return MemoryLeakReport.builder()
            .trackedReferences(trackedReferences.size())
            .referenceQueues(referenceQueues.size())
            .memoryUsage(memoryBean.getHeapMemoryUsage())
            .suspiciousObjects(findSuspiciousObjects())
            .build();
    }

    /**
     * Finds objects suspected of leaking. Simplified placeholder.
     */
    private List<SuspiciousObject> findSuspiciousObjects() {
        return Collections.emptyList();
    }

    // TODO(review): the four hooks below were called by the original code but
    // never defined anywhere in this file; conservative no-op implementations
    // keep the class compilable until real ones are supplied.

    /** Whether heap-dump analysis is enabled; disabled by default (dumps are expensive). */
    private boolean shouldAnalyzeHeapDump() {
        return false;
    }

    /** Analyzes a previously generated heap dump file. */
    private void analyzeHeapDumpFile(String dumpFile) {
        logger.info("Heap dump available for analysis: {}", dumpFile);
    }

    /** Inspects reference counts of tracked objects. */
    private void checkObjectReferenceCounts() {
        // no-op placeholder
    }

    /** Inspects objects that have survived unusually long. */
    private void checkLongLivedObjects() {
        // no-op placeholder
    }
}
/**
 * Query-result cache that also serves near-identical queries: on a miss it
 * scans for a cached query with at least 95% similarity and re-ranks that
 * entry's results for the new query vector.
 */
public class IntelligentQueryCache {
    private static final Logger logger = LoggerFactory.getLogger(IntelligentQueryCache.class);
    private final Cache<String, CachedSearchResult> cache;
    private final QuerySimilarityCalculator similarityCalculator;
    private final CacheHitPredictor hitPredictor;

    public IntelligentQueryCache(CacheConfig config) {
        this.cache = Caffeine.newBuilder()
            // The original called config.getMaxSize()/getTtl(), which do not
            // exist on this file's CacheConfig; use its query-cache settings.
            .maximumSize(config.getQueryCacheSize())
            .expireAfterWrite(config.getQueryCacheTtl())
            .removalListener(this::onRemoval)
            .recordStats()
            .build();
        this.similarityCalculator = new QuerySimilarityCalculator();
        this.hitPredictor = new CacheHitPredictor();
    }

    /**
     * Looks up cached results for a query; on a direct miss, consults
     * sufficiently similar cached queries and adapts their results.
     *
     * @return the (possibly adjusted) results, or empty on a full miss
     */
    public Optional<List<SearchResult>> get(Vector query, int k, SearchOptions options) {
        String queryKey = generateQueryKey(query, k, options);
        // Direct cache hit.
        CachedSearchResult cachedResult = cache.getIfPresent(queryKey);
        if (cachedResult != null) {
            // The original never recorded direct hits, so getHitCount() was
            // always zero despite CachedSearchResult providing recordHit().
            cachedResult.recordHit();
            return Optional.of(cachedResult.getResults());
        }
        // Fall back to a similar cached query.
        Optional<CachedSearchResult> similarResult = findSimilarQuery(query, k, options);
        if (similarResult.isPresent()) {
            CachedSearchResult similar = similarResult.get();
            // Re-rank the similar query's results against the new query.
            List<SearchResult> adjustedResults = adjustSimilarResults(
                similar.getResults(), query, similar.getQuery());
            // Cache the adjusted results under the new key.
            cache.put(queryKey, new CachedSearchResult(query, adjustedResults, System.currentTimeMillis()));
            return Optional.of(adjustedResults);
        }
        return Optional.empty();
    }

    /**
     * Caches search results when the predictor expects them to be reused.
     */
    public void put(Vector query, int k, SearchOptions options, List<SearchResult> results) {
        String queryKey = generateQueryKey(query, k, options);
        double predictedHitRate = hitPredictor.predict(query, k, options);
        // Only cache queries with at least a 10% predicted hit rate.
        if (predictedHitRate > 0.1) {
            CachedSearchResult cachedResult = new CachedSearchResult(
                query, new ArrayList<>(results), System.currentTimeMillis());
            cache.put(queryKey, cachedResult);
        }
    }

    /**
     * Scans cached entries for the most similar query above the 95%
     * similarity threshold.
     *
     * <p>NOTE(review): this is an O(n) scan over the entire cache per miss;
     * consider an approximate index over cached query vectors if the cache
     * grows large.
     */
    private Optional<CachedSearchResult> findSimilarQuery(Vector query, int k, SearchOptions options) {
        double maxSimilarity = 0.0;
        CachedSearchResult mostSimilar = null;
        for (CachedSearchResult cached : cache.asMap().values()) {
            double similarity = similarityCalculator.calculate(query, cached.getQuery());
            if (similarity > 0.95 && similarity > maxSimilarity) {
                maxSimilarity = similarity;
                mostSimilar = cached;
            }
        }
        return Optional.ofNullable(mostSimilar);
    }

    /**
     * Re-computes distances of borrowed results against the new query and
     * re-sorts them.
     *
     * <p>NOTE(review): calculateDistance and generateQueryKey are referenced
     * but not defined in this file — confirm they exist elsewhere in the
     * project.
     */
    private List<SearchResult> adjustSimilarResults(List<SearchResult> originalResults,
                                                    Vector newQuery, Vector originalQuery) {
        return originalResults.stream()
            .map(result -> new SearchResult(
                result.getId(),
                calculateDistance(newQuery, result.getVector()),
                result.getVector()
            ))
            .sorted(Comparator.comparing(SearchResult::getDistance))
            .collect(Collectors.toList());
    }

    /**
     * Removal listener; feeds evictions back into the hit predictor.
     */
    private void onRemoval(String key, CachedSearchResult value, RemovalCause cause) {
        logger.debug("Cache entry removed: key={}, cause={}", key, cause);
        hitPredictor.updateModel(value, cause);
    }
}
/**
 * A cached search result: the query it answers, its result list, the time it
 * was cached, and a thread-safe hit counter.
 */
public class CachedSearchResult {
    private final Vector query;
    private final List<SearchResult> results;
    private final long timestamp;
    // The original declared "private volatile int hitCount" and used
    // hitCount++: volatile gives visibility but ++ is a non-atomic
    // read-modify-write, so concurrent hits lost increments.
    private final AtomicInteger hitCount = new AtomicInteger();

    public CachedSearchResult(Vector query, List<SearchResult> results, long timestamp) {
        this.query = query;
        this.results = results;
        this.timestamp = timestamp;
    }

    /** Atomically records one cache hit. */
    public void recordHit() {
        hitCount.incrementAndGet();
    }

    // Getters
    public Vector getQuery() { return query; }
    public List<SearchResult> getResults() { return results; }
    public long getTimestamp() { return timestamp; }
    public int getHitCount() { return hitCount.get(); }
}
/**
 * Schedules background precomputation of hot query patterns and cluster
 * centroids, caching the results for fast lookup.
 *
 * <p>Call {@link #shutdown()} to stop the background tasks.
 */
public class PrecomputationManager {
    private static final Logger logger = LoggerFactory.getLogger(PrecomputationManager.class);
    private final HnswIndex index;
    private final ScheduledExecutorService scheduler;
    // NOTE(review): ClusterData instances are stored here too (see
    // precomputeClusterData), so ClusterData must be a subtype of
    // PrecomputedData — confirm in its declaration.
    private final Map<String, PrecomputedData> precomputedCache;
    private final PrecomputationStrategy strategy;

    public PrecomputationManager(HnswIndex index, PrecomputationStrategy strategy) {
        this.index = index;
        this.strategy = strategy;
        this.scheduler = Executors.newScheduledThreadPool(2);
        this.precomputedCache = new ConcurrentHashMap<>();
        startPrecomputation();
    }

    /**
     * Schedules the recurring precomputation tasks.
     */
    private void startPrecomputation() {
        // Hot-pattern precomputation: immediately, then hourly.
        scheduler.scheduleWithFixedDelay(
            this::precomputeHotData,
            0, 60, TimeUnit.MINUTES
        );
        // Centroid precomputation: after 5 minutes, then every 2 hours.
        scheduler.scheduleWithFixedDelay(
            this::precomputeCentroids,
            5, 120, TimeUnit.MINUTES
        );
    }

    /**
     * Stops the background precomputation tasks. The original never shut the
     * scheduler down, leaking its threads.
     */
    public void shutdown() {
        scheduler.shutdownNow();
    }

    /**
     * Precomputes results for the strategy's hot query patterns.
     */
    private void precomputeHotData() {
        logger.info("Starting hot data precomputation...");
        try {
            List<QueryPattern> hotPatterns = strategy.getHotQueryPatterns();
            for (QueryPattern pattern : hotPatterns) {
                precomputePattern(pattern);
            }
            logger.info("Hot data precomputation completed, processed {} patterns", hotPatterns.size());
        } catch (Exception e) {
            logger.error("Failed to precompute hot data", e);
        }
    }

    /**
     * Runs one pattern's representative query and caches the results.
     * Already-computed patterns are skipped.
     */
    private void precomputePattern(QueryPattern pattern) {
        String patternKey = pattern.getKey();
        if (precomputedCache.containsKey(patternKey)) {
            return;
        }
        Vector representativeQuery = pattern.generateRepresentativeQuery();
        List<SearchResult> results = index.search(
            representativeQuery,
            pattern.getK(),
            pattern.getSearchOptions()
        );
        PrecomputedData data = new PrecomputedData(
            representativeQuery, results, System.currentTimeMillis());
        precomputedCache.put(patternKey, data);
        logger.debug("Precomputed pattern: {}", patternKey);
    }

    /**
     * Clusters the index and precomputes per-cluster data.
     */
    private void precomputeCentroids() {
        logger.info("Starting centroid precomputation...");
        try {
            ClusterAnalyzer analyzer = new ClusterAnalyzer(index);
            // Partition into 100 clusters.
            List<VectorCluster> clusters = analyzer.analyzeClusters(100);
            for (VectorCluster cluster : clusters) {
                precomputeClusterData(cluster);
            }
            logger.info("Centroid precomputation completed, processed {} clusters", clusters.size());
        } catch (Exception e) {
            logger.error("Failed to precompute centroids", e);
        }
    }

    /**
     * Caches a cluster's centroid and the centroid's 50 nearest neighbors.
     */
    private void precomputeClusterData(VectorCluster cluster) {
        Vector centroid = cluster.getCentroid();
        String clusterKey = "cluster_" + cluster.getId();
        List<SearchResult> nearestToCenter = index.search(centroid, 50);
        ClusterData clusterData = new ClusterData(
            cluster.getId(), centroid, nearestToCenter, cluster.getMembers());
        precomputedCache.put(clusterKey, clusterData);
    }

    /**
     * Returns precomputed results adapted to the given query, if a close
     * enough precomputed entry exists.
     */
    public Optional<List<SearchResult>> getPrecomputedResults(Vector query, int k, SearchOptions options) {
        PrecomputedData bestMatch = findBestMatch(query, k, options);
        if (bestMatch != null) {
            return Optional.of(adjustPrecomputedResults(bestMatch, query, k));
        }
        return Optional.empty();
    }

    /**
     * Finds the cluster entry whose centroid is most similar to the query
     * (threshold 0.8).
     *
     * <p>NOTE(review): calculateSimilarity and calculateDistance are
     * referenced but not defined in this file — confirm they exist elsewhere
     * in the project.
     */
    private PrecomputedData findBestMatch(Vector query, int k, SearchOptions options) {
        double bestSimilarity = 0.0;
        PrecomputedData bestMatch = null;
        for (PrecomputedData data : precomputedCache.values()) {
            if (data instanceof ClusterData) {
                ClusterData clusterData = (ClusterData) data;
                double similarity = calculateSimilarity(query, clusterData.getCentroid());
                if (similarity > 0.8 && similarity > bestSimilarity) {
                    bestSimilarity = similarity;
                    bestMatch = data;
                }
            }
        }
        return bestMatch;
    }

    /**
     * Re-computes distances of precomputed results against the query, sorts,
     * and keeps the top k.
     */
    private List<SearchResult> adjustPrecomputedResults(PrecomputedData precomputed,
                                                        Vector query, int k) {
        return precomputed.getResults().stream()
            .map(result -> new SearchResult(
                result.getId(),
                calculateDistance(query, result.getVector()),
                result.getVector()
            ))
            .sorted(Comparator.comparing(SearchResult::getDistance))
            .limit(k)
            .collect(Collectors.toList());
    }
}
/**
 * Warms up an HNSW index after startup: touches connection data, vector
 * data, distance code paths and caches so the first real queries are fast.
 *
 * <p>Call {@link #shutdown()} when warmup is done; otherwise the executor
 * threads linger.
 */
public class IndexWarmupManager {
    private static final Logger logger = LoggerFactory.getLogger(IndexWarmupManager.class);
    private final HnswIndex index;
    private final WarmupStrategy strategy;
    private final ExecutorService warmupExecutor;

    public IndexWarmupManager(HnswIndex index, WarmupStrategy strategy) {
        this.index = index;
        this.strategy = strategy;
        this.warmupExecutor = Executors.newFixedThreadPool(
            Runtime.getRuntime().availableProcessors());
    }

    /**
     * Stops the warmup executor. The original class never shut it down,
     * leaking one thread per CPU core.
     */
    public void shutdown() {
        warmupExecutor.shutdownNow();
    }

    /**
     * Runs all warmup phases concurrently and waits for them to finish.
     */
    public void warmupIndex() {
        logger.info("Starting index warmup...");
        long startTime = System.currentTimeMillis();
        List<CompletableFuture<Void>> warmupTasks = Arrays.asList(
            CompletableFuture.runAsync(this::warmupConnections, warmupExecutor),
            CompletableFuture.runAsync(this::warmupVectorData, warmupExecutor),
            CompletableFuture.runAsync(this::warmupDistanceCalculations, warmupExecutor),
            CompletableFuture.runAsync(this::warmupCaches, warmupExecutor)
        );
        // Block until every warmup task has completed.
        CompletableFuture.allOf(warmupTasks.toArray(new CompletableFuture[0]))
            .join();
        long duration = System.currentTimeMillis() - startTime;
        logger.info("Index warmup completed in {} ms", duration);
    }

    /**
     * Touches every node's connection sets to pull them into memory.
     */
    private void warmupConnections() {
        logger.debug("Warming up connections...");
        Collection<HnswNode> nodes = index.getAllNodes();
        int accessCount = 0;
        for (HnswNode node : nodes) {
            for (int level = 0; level <= node.getLevel(); level++) {
                Set<Long> connections = node.getConnections(level);
                accessCount += connections.size();
            }
        }
        logger.debug("Warmed up {} connections", accessCount);
    }

    /**
     * Touches every node's vector data to pull it into memory.
     */
    private void warmupVectorData() {
        logger.debug("Warming up vector data...");
        Collection<HnswNode> nodes = index.getAllNodes();
        int vectorCount = 0;
        for (HnswNode node : nodes) {
            Vector vector = node.getVector();
            float[] data = vector.getData();
            // Read first and last element so the access is not optimized away.
            @SuppressWarnings("unused")
            float sum = data[0] + data[data.length - 1];
            vectorCount++;
        }
        logger.debug("Warmed up {} vectors", vectorCount);
    }

    /**
     * Runs sample distance computations to warm JIT-compiled code paths.
     */
    private void warmupDistanceCalculations() {
        logger.debug("Warming up distance calculations...");
        List<Vector> sampleVectors = strategy.getSampleVectors(100);
        DistanceMetric metric = index.getDistanceMetric();
        // Each vector against up to 9 of its successors.
        for (int i = 0; i < sampleVectors.size(); i++) {
            for (int j = i + 1; j < Math.min(i + 10, sampleVectors.size()); j++) {
                metric.distance(
                    sampleVectors.get(i).getData(),
                    sampleVectors.get(j).getData()
                );
            }
        }
        logger.debug("Completed distance calculation warmup");
    }

    /**
     * Executes small sample searches to populate the query caches.
     */
    private void warmupCaches() {
        logger.debug("Warming up caches...");
        List<Vector> queryVectors = strategy.getSampleQueries(50);
        for (Vector query : queryVectors) {
            index.search(query, 10);
        }
        logger.debug("Cache warmup completed");
    }

    /**
     * Runs the warmup phases sequentially with per-phase time budgets, to
     * avoid spiking system load.
     */
    public void progressiveWarmup() {
        logger.info("Starting progressive warmup...");
        warmupPhase("Connections", this::warmupConnections, 30);
        warmupPhase("Vector Data", this::warmupVectorData, 45);
        warmupPhase("Distance Calculations", this::warmupDistanceCalculations, 15);
        warmupPhase("Caches", this::warmupCaches, 10);
        logger.info("Progressive warmup completed");
    }

    /**
     * Runs one warmup phase, cancelling it when it exceeds its budget.
     *
     * @param maxDurationSeconds time budget before the phase is cancelled
     */
    private void warmupPhase(String phaseName, Runnable warmupTask, long maxDurationSeconds) {
        logger.debug("Starting warmup phase: {}", phaseName);
        long startTime = System.currentTimeMillis();
        Future<?> task = warmupExecutor.submit(warmupTask);
        try {
            task.get(maxDurationSeconds, TimeUnit.SECONDS);
            long duration = System.currentTimeMillis() - startTime;
            logger.debug("Warmup phase '{}' completed in {} ms", phaseName, duration);
        } catch (TimeoutException e) {
            logger.warn("Warmup phase '{}' timed out after {} seconds", phaseName, maxDurationSeconds);
            task.cancel(true);
        } catch (Exception e) {
            logger.error("Warmup phase '{}' failed", phaseName, e);
        }
    }
}
/**
 * Supplies the sample data an {@code IndexWarmupManager} and
 * {@code PrecomputationManager} use for warmup and precomputation.
 */
public interface WarmupStrategy {
// Returns up to {@code count} vectors drawn from the index.
List<Vector> getSampleVectors(int count);
// Returns up to {@code count} query vectors for cache warmup.
List<Vector> getSampleQueries(int count);
// Returns the query patterns considered "hot" for precomputation.
List<QueryPattern> getHotQueryPatterns();
}
/**
 * Default warmup strategy: samples vectors straight from the index,
 * derives queries by perturbing them with noise, and exposes a fixed set of
 * hot query patterns.
 *
 * <p>The RNG is seeded (42) so warmup runs are reproducible.
 */
public class DefaultWarmupStrategy implements WarmupStrategy {
    private final HnswIndex index;
    private final Random random = new Random(42);

    public DefaultWarmupStrategy(HnswIndex index) {
        this.index = index;
    }

    /**
     * Returns up to {@code count} vectors from the index's nodes.
     */
    @Override
    public List<Vector> getSampleVectors(int count) {
        // Stream directly over the node collection; the original copied every
        // node into an ArrayList just to take the first `count` of them.
        return index.getAllNodes().stream()
            .limit(count)
            .map(HnswNode::getVector)
            .collect(Collectors.toList());
    }

    /**
     * Returns sample queries built by adding noise to sampled vectors.
     */
    @Override
    public List<Vector> getSampleQueries(int count) {
        List<Vector> sampleVectors = getSampleVectors(count);
        return sampleVectors.stream()
            .map(this::addNoise)
            .collect(Collectors.toList());
    }

    /**
     * Perturbs each component with uniform noise in [-0.05, 0.05).
     */
    private Vector addNoise(Vector original) {
        float[] originalData = original.getData();
        float[] noisyData = new float[originalData.length];
        for (int i = 0; i < originalData.length; i++) {
            float noise = (random.nextFloat() - 0.5f) * 0.1f;
            noisyData[i] = originalData[i] + noise;
        }
        return new Vector(noisyData);
    }

    /**
     * Returns the fixed small/medium/large-k query patterns.
     */
    @Override
    public List<QueryPattern> getHotQueryPatterns() {
        return Arrays.asList(
            new QueryPattern("small_k", 10, SearchOptions.defaultOptions()),
            new QueryPattern("medium_k", 50, SearchOptions.defaultOptions()),
            new QueryPattern("large_k", 100, SearchOptions.defaultOptions())
        );
    }
}
本章介绍了向量搜索引擎的高级特性与性能优化技术:
这些优化技术能够显著提升向量搜索引擎的性能和稳定性,为生产环境提供可靠的保障。
思考题: