向量数据库需要支持完整的 CRUD(Create、Read、Update、Delete)操作,同时保证操作的一致性、并发安全性和高性能。
package com.jvector.api;
import com.jvector.core.Vector;
import com.jvector.core.SearchResult;
import java.util.List;
import java.util.Optional;
import java.util.Map;
/**
 * Full CRUD contract of the vector database: inserting, reading, replacing and removing vectors
 * addressed by {@code long} IDs, together with similarity search and simple statistics.
 */
public interface VectorCrudOperations {

    // ------------------------------------------------------------------ create

    /**
     * Stores a vector under an ID chosen by the implementation.
     *
     * @param vector the vector to store
     * @return the ID the vector was stored under
     */
    long add(Vector vector);

    /**
     * Stores a vector under the given ID.
     *
     * @param id     the ID to store the vector under
     * @param vector the vector to store
     */
    void add(long id, Vector vector);

    /**
     * Stores several vectors under IDs chosen by the implementation.
     *
     * @param vectors the vectors to store
     * @return the IDs the vectors were stored under
     */
    List<Long> addBatch(List<Vector> vectors);

    /**
     * Stores several vectors, each under the ID it is keyed by.
     *
     * @param vectors vectors keyed by the ID to store them under
     */
    void addBatch(Map<Long, Vector> vectors);

    // -------------------------------------------------------------------- read

    /**
     * Looks up a single vector.
     *
     * @param id the ID to look up
     * @return the stored vector, or empty if none is found
     */
    Optional<Vector> get(long id);

    /**
     * Looks up several vectors at once.
     *
     * @param ids the IDs to look up
     * @return the vectors found, keyed by ID
     */
    Map<Long, Vector> getBatch(List<Long> ids);

    /**
     * Finds the {@code k} stored vectors most similar to {@code query}.
     *
     * @param query the query vector
     * @param k     maximum number of results
     * @return the search results
     */
    List<SearchResult> search(Vector query, int k);

    /**
     * Finds the {@code k} stored vectors most similar to {@code query}, honouring extra options.
     *
     * @param query   the query vector
     * @param k       maximum number of results
     * @param options search tuning and filtering options
     * @return the search results
     */
    List<SearchResult> search(Vector query, int k, SearchOptions options);

    /**
     * Range search: finds stored vectors whose distance to {@code query} is at most
     * {@code maxDistance}.
     *
     * @param query       the query vector
     * @param maxDistance inclusive distance bound
     * @return the search results
     */
    List<SearchResult> searchWithinDistance(Vector query, float maxDistance);

    // ------------------------------------------------------------------ update

    /**
     * Replaces the vector stored under {@code id}.
     *
     * @param id        the ID whose vector is replaced
     * @param newVector the replacement vector
     * @return whether the vector was replaced
     */
    boolean update(long id, Vector newVector);

    /**
     * Replaces several vectors at once.
     *
     * @param vectors replacement vectors keyed by ID
     * @return per-ID flag telling whether that vector was replaced
     */
    Map<Long, Boolean> updateBatch(Map<Long, Vector> vectors);

    // ------------------------------------------------------------------ delete

    /**
     * Removes the vector stored under {@code id}.
     *
     * @param id the ID to remove
     * @return whether a vector was removed
     */
    boolean delete(long id);

    /**
     * Removes several vectors at once.
     *
     * @param ids the IDs to remove
     * @return per-ID flag telling whether that vector was removed
     */
    Map<Long, Boolean> deleteBatch(List<Long> ids);

    // -------------------------------------------------------------- statistics

    /**
     * @return the number of stored vectors
     */
    long size();

    /**
     * @param id the ID to test
     * @return whether a vector is stored under {@code id}
     */
    boolean contains(long id);

    /**
     * @return the IDs of all stored vectors
     */
    List<Long> getAllIds();
}
/**
 * Inserts vectors into an {@code HnswIndex}.
 *
 * <p>Each insert is validated with a {@code VectorValidator} and recorded with an
 * {@code OperationAuditor}. The class assigns IDs for inserts that do not supply one, and attempts
 * a best-effort rollback when the index insert fails.
 */
public class VectorAddOperation {
    private static final Logger logger = LoggerFactory.getLogger(VectorAddOperation.class);
    private final HnswIndex index;
    private final VectorValidator validator;
    private final OperationAuditor auditor;
    /** Next candidate for auto-assigned IDs; advanced past every explicitly supplied ID. */
    private final AtomicLong idGenerator;

    public VectorAddOperation(HnswIndex index) {
        this.index = index;
        this.validator = new VectorValidator();
        this.auditor = new OperationAuditor();
        this.idGenerator = new AtomicLong(1);
    }

    /**
     * Adds a vector under an automatically assigned ID.
     *
     * @param vector the vector to add
     * @return the ID the vector was stored under
     * @throws VectorOperationException if validation or the index insert fails
     */
    public long add(Vector vector) {
        long operationId = auditor.startOperation("ADD", null);
        try {
            // 1. Validate input
            validator.validate(vector);
            // 2. Pick an ID that is not already present in the index
            long id = nextFreeId();
            // 3. Insert
            return executeAdd(id, vector, operationId);
        } catch (Exception e) {
            auditor.recordError(operationId, e);
            throw new VectorOperationException("Failed to add vector", e);
        }
    }

    /**
     * Adds a vector under a caller-supplied ID.
     *
     * @param id     the ID to store the vector under
     * @param vector the vector to add
     * @throws VectorOperationException if validation fails, the ID is already present (the cause is
     *                                  then a {@code DuplicateIdException}), or the insert fails
     */
    public void add(long id, Vector vector) {
        long operationId = auditor.startOperation("ADD", id);
        try {
            // 1. Validate input
            validator.validate(vector);
            validator.validateId(id);
            // 2. Reject IDs that are already taken
            if (index.contains(id)) {
                throw new DuplicateIdException("Vector with id " + id + " already exists");
            }
            // Move the auto-ID generator past this ID so add(Vector) never hands it out later.
            reserveThrough(id);
            // 3. Insert
            executeAdd(id, vector, operationId);
        } catch (Exception e) {
            auditor.recordError(operationId, e);
            throw new VectorOperationException("Failed to add vector with id " + id, e);
        }
    }

    /** Draws IDs from the generator until one that is not present in the index is found. */
    private long nextFreeId() {
        long id = idGenerator.getAndIncrement();
        // The index may already hold IDs this generator never produced (e.g. a pre-populated index).
        while (index.contains(id)) {
            id = idGenerator.getAndIncrement();
        }
        return id;
    }

    /** Ensures the generator's next value is strictly greater than {@code id}. */
    private void reserveThrough(long id) {
        if (id < Long.MAX_VALUE) { // avoid overflowing id + 1
            idGenerator.accumulateAndGet(id + 1, Math::max);
        }
    }

    /**
     * Inserts the vector into the index and records success with the elapsed time in nanoseconds.
     * On failure it attempts a rollback and rethrows the original exception.
     */
    private long executeAdd(long id, Vector vector, long operationId) {
        long startTime = System.nanoTime();
        try {
            index.add(id, vector);
            long duration = System.nanoTime() - startTime;
            auditor.recordSuccess(operationId, duration);
            logger.debug("Successfully added vector with id: {}, duration: {}ms",
                    id, duration / 1_000_000);
            return id;
        } catch (Exception e) {
            rollbackAdd(id);
            throw e;
        }
    }

    /**
     * Best-effort removal of {@code id} after a failed insert. Failures are logged, not thrown.
     * NOTE(review): if the insert failed because a concurrent writer already stored this ID, this
     * removes that writer's vector.
     */
    private void rollbackAdd(long id) {
        try {
            index.remove(id);
            logger.debug("Rolled back add operation for id: {}", id);
        } catch (Exception e) {
            logger.warn("Failed to rollback add operation for id: {}", id, e);
        }
    }
}
/**
 * Batched vector insertion.
 *
 * <p>The input is split into chunks of at most {@code batchSize} vectors. Each chunk is inserted
 * inside its own index transaction, and the chunks run in parallel on
 * {@code ForkJoinPool.commonPool()}. Atomicity is per chunk only: when one chunk fails, chunks that
 * already committed stay committed.
 */
public class BatchAddOperation {
    private static final Logger logger = LoggerFactory.getLogger(BatchAddOperation.class);
    private final HnswIndex index;
    private final int batchSize;
    private final ExecutorService executorService;

    /**
     * @param index     the index to insert into
     * @param batchSize maximum number of vectors per transaction; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BatchAddOperation(HnswIndex index, int batchSize) {
        if (batchSize <= 0) {
            // A zero step would make the chunking loops spin forever; a negative one underflows.
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        this.index = index;
        this.batchSize = batchSize;
        this.executorService = ForkJoinPool.commonPool();
    }

    /**
     * Adds vectors under IDs assigned by the index, processing chunks in parallel.
     *
     * @param vectors the vectors to add
     * @return the assigned IDs, in the same order as {@code vectors}
     * @throws RuntimeException if any chunk fails or the calling thread is interrupted
     */
    public List<Long> addBatch(List<Vector> vectors) {
        if (vectors.isEmpty()) {
            return Collections.emptyList();
        }
        List<CompletableFuture<List<Long>>> batchFutures = new ArrayList<>();
        for (int i = 0; i < vectors.size(); i += batchSize) {
            int endIndex = Math.min(i + batchSize, vectors.size());
            // Copy the chunk: a subList view breaks if the caller mutates the list mid-flight.
            List<Vector> batch = new ArrayList<>(vectors.subList(i, endIndex));
            batchFutures.add(CompletableFuture.supplyAsync(() -> processBatch(batch), executorService));
        }
        // Chunks are joined in submission order, so IDs line up with the input order.
        List<Long> resultIds = new ArrayList<>(vectors.size());
        for (CompletableFuture<List<Long>> batchFuture : batchFutures) {
            try {
                resultIds.addAll(batchFuture.get());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Batch add operation failed", e);
            } catch (Exception e) {
                throw new RuntimeException("Batch add operation failed", e);
            }
        }
        return resultIds;
    }

    /**
     * Inserts one chunk inside a single transaction and returns the IDs assigned by the index.
     * NOTE(review): assumes closing an uncommitted {@code Transaction} rolls it back — confirm.
     */
    private List<Long> processBatch(List<Vector> batch) {
        List<Long> batchIds = new ArrayList<>(batch.size());
        try (Transaction tx = index.beginTransaction()) {
            for (Vector vector : batch) {
                long id = index.add(vector);
                batchIds.add(id);
            }
            tx.commit();
        } catch (Exception e) {
            logger.error("Batch processing failed, rolling back", e);
            throw new RuntimeException("Batch add failed", e);
        }
        return batchIds;
    }

    /**
     * Adds vectors under caller-supplied IDs, processing chunks in parallel.
     *
     * @param vectors vectors keyed by the ID to store them under
     * @throws DuplicateIdException if any ID is already present in the index
     */
    public void addBatch(Map<Long, Vector> vectors) {
        if (vectors.isEmpty()) {
            return;
        }
        validateUniqueIds(vectors.keySet());
        // `entries` is a private copy, so the subList views below cannot be disturbed by the caller.
        List<Map.Entry<Long, Vector>> entries = new ArrayList<>(vectors.entrySet());
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (int i = 0; i < entries.size(); i += batchSize) {
            int endIndex = Math.min(i + batchSize, entries.size());
            List<Map.Entry<Long, Vector>> batch = entries.subList(i, endIndex);
            futures.add(CompletableFuture.runAsync(() -> processBatchWithIds(batch), executorService));
        }
        CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0])).join();
    }

    /** Inserts one chunk of (ID, vector) pairs inside a single transaction. */
    private void processBatchWithIds(List<Map.Entry<Long, Vector>> batch) {
        try (Transaction tx = index.beginTransaction()) {
            for (Map.Entry<Long, Vector> entry : batch) {
                index.add(entry.getKey(), entry.getValue());
            }
            tx.commit();
        }
    }

    /** Throws if any of {@code ids} is already present in the index. */
    private void validateUniqueIds(Set<Long> ids) {
        for (Long id : ids) {
            if (index.contains(id)) {
                throw new DuplicateIdException("Vector with id " + id + " already exists");
            }
        }
    }
}
/**
 * Read-side operations over an {@code HnswIndex}: lookup by ID, exact or approximate k-NN search
 * with optional result caching, and distance-bounded range search.
 */
public class VectorQueryOperation {
    private static final Logger logger = LoggerFactory.getLogger(VectorQueryOperation.class);
    /** Number of nearest neighbours fetched by {@link #searchWithinDistance} before filtering. */
    private static final int MAX_RANGE_CANDIDATES = 1000;
    private final HnswIndex index;
    private final QueryCache queryCache;
    private final QueryOptimizer optimizer;

    public VectorQueryOperation(HnswIndex index) {
        this.index = index;
        this.queryCache = new QueryCache(1000); // cache up to 1000 query results
        this.optimizer = new QueryOptimizer();
    }

    /**
     * Returns the vector stored under {@code id}.
     *
     * <p>Any exception during lookup is logged and reported as empty, so a failure is
     * indistinguishable from a missing ID.
     */
    public Optional<Vector> get(long id) {
        try {
            HnswNode node = index.getNode(id);
            return node != null ? Optional.of(node.getVector()) : Optional.empty();
        } catch (Exception e) {
            logger.error("Failed to get vector with id: {}", id, e);
            return Optional.empty();
        }
    }

    /**
     * Looks up several IDs in parallel; IDs for which {@link #get} returns empty are omitted.
     */
    public Map<Long, Vector> getBatch(List<Long> ids) {
        Map<Long, Vector> result = new HashMap<>();
        ids.parallelStream().forEach(id -> {
            Optional<Vector> vector = get(id);
            if (vector.isPresent()) {
                // HashMap is not thread-safe; serialize writes from the parallel stream.
                synchronized (result) {
                    result.put(id, vector.get());
                }
            }
        });
        return result;
    }

    /** k-NN search with {@link SearchOptions#defaultOptions()}. */
    public List<SearchResult> search(Vector query, int k) {
        return search(query, k, SearchOptions.defaultOptions());
    }

    /**
     * k-NN search honouring {@code options}.
     *
     * <p>When caching is enabled, results are served from and stored into the query cache.
     *
     * @throws VectorOperationException if the underlying search fails
     */
    public List<SearchResult> search(Vector query, int k, SearchOptions options) {
        boolean useCache = options.isEnableCache();
        // Build the cache key only when it will actually be used.
        String cacheKey = useCache ? generateCacheKey(query, k, options) : null;
        if (useCache) {
            List<SearchResult> cachedResult = queryCache.get(cacheKey);
            if (cachedResult != null) {
                return cachedResult;
            }
        }
        SearchParameters params = optimizer.optimize(query, k, options, index);
        List<SearchResult> results = executeSearch(query, k, params, options);
        if (useCache && results != null) {
            queryCache.put(cacheKey, results);
        }
        return results;
    }

    /**
     * Runs an exact or approximate (ef-parameterised) search, post-processes the results, and
     * logs the elapsed time. Any failure is wrapped in {@code VectorOperationException}.
     */
    private List<SearchResult> executeSearch(Vector query, int k,
                                             SearchParameters params, SearchOptions options) {
        long startTime = System.nanoTime();
        try {
            List<SearchResult> results;
            if (options.isExactSearch()) {
                results = index.exactSearch(query, k);
            } else {
                results = index.search(query, k, params.getEf());
            }
            results = postProcessResults(results, options);
            long duration = System.nanoTime() - startTime;
            logger.debug("Search completed in {}ms, returned {} results",
                    duration / 1_000_000, results.size());
            return results;
        } catch (Exception e) {
            logger.error("Search operation failed", e);
            throw new VectorOperationException("Search failed", e);
        }
    }

    /**
     * Applies the distance bound, ID exclusion and optional de-duplication from {@code options}.
     */
    private List<SearchResult> postProcessResults(List<SearchResult> results, SearchOptions options) {
        // The "no bound" sentinel is negative (default -1); an explicit 0 means "exact matches only".
        if (options.getMaxDistance() >= 0) {
            results = results.stream()
                    .filter(r -> r.getDistance() <= options.getMaxDistance())
                    .collect(Collectors.toList());
        }
        if (options.getExcludeIds() != null && !options.getExcludeIds().isEmpty()) {
            Set<Long> excludeSet = new HashSet<>(options.getExcludeIds());
            results = results.stream()
                    .filter(r -> !excludeSet.contains(r.getId()))
                    .collect(Collectors.toList());
        }
        if (options.isDeduplication()) {
            results = deduplicateResults(results);
        }
        return results;
    }

    /**
     * Range search: returns results within {@code maxDistance} of {@code query}.
     *
     * <p>Only the {@value #MAX_RANGE_CANDIDATES} nearest neighbours are examined, so matches beyond
     * that rank are not returned. Results are never cached.
     */
    public List<SearchResult> searchWithinDistance(Vector query, float maxDistance) {
        // Take the min before narrowing: casting a large size straight to int could yield a negative k.
        int k = (int) Math.min(MAX_RANGE_CANDIDATES, index.size());
        if (k <= 0) {
            return new ArrayList<>(); // empty index: nothing to search
        }
        SearchOptions options = SearchOptions.builder()
                .maxDistance(maxDistance)
                .enableCache(false)
                .build();
        List<SearchResult> results = search(query, k, options);
        // Re-apply the bound so this method's contract does not depend on the sentinel handling above.
        return results.stream()
                .filter(r -> r.getDistance() <= maxDistance)
                .collect(Collectors.toList());
    }
}
/**
 * Options for a similarity search.
 *
 * <p>Instances come from {@link #defaultOptions()} or {@link #builder()}. Every
 * {@link Builder#build()} call returns a fresh instance that shares no mutable state with the
 * builder, with previously built instances, or with lists passed to the builder.
 */
public class SearchOptions {
    /** Use exact (brute-force) search instead of the approximate index search. Default false. */
    private boolean exactSearch = false;
    /** Whether results may be served from / stored into the query cache. Default true. */
    private boolean enableCache = true;
    /** Whether duplicate results are removed. Default false. */
    private boolean deduplication = false;
    /** Maximum result distance; default -1 (no distance bound). */
    private float maxDistance = -1;
    /** IDs to drop from the results; null means none. */
    private List<Long> excludeIds = null;
    /** Free-form extra parameters. */
    private Map<String, Object> customParams = new HashMap<>();

    /** Returns a new builder initialised with the default option values. */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns a new instance with all default values. */
    public static SearchOptions defaultOptions() {
        return new SearchOptions();
    }

    /** Builder for {@link SearchOptions}; each {@link #build()} yields an independent snapshot. */
    public static class Builder {
        private final SearchOptions options = new SearchOptions();

        public Builder exactSearch(boolean exact) {
            options.exactSearch = exact;
            return this;
        }

        public Builder enableCache(boolean enable) {
            options.enableCache = enable;
            return this;
        }

        public Builder deduplication(boolean enable) {
            options.deduplication = enable;
            return this;
        }

        public Builder maxDistance(float distance) {
            options.maxDistance = distance;
            return this;
        }

        /** Sets the IDs to exclude; the list is copied, so later changes by the caller are ignored. */
        public Builder excludeIds(List<Long> ids) {
            options.excludeIds = ids == null ? null : new ArrayList<>(ids);
            return this;
        }

        public Builder customParam(String key, Object value) {
            options.customParams.put(key, value);
            return this;
        }

        /** Returns a snapshot of the current builder state; later builder calls do not affect it. */
        public SearchOptions build() {
            SearchOptions built = new SearchOptions();
            built.exactSearch = options.exactSearch;
            built.enableCache = options.enableCache;
            built.deduplication = options.deduplication;
            built.maxDistance = options.maxDistance;
            built.excludeIds = options.excludeIds == null ? null : new ArrayList<>(options.excludeIds);
            built.customParams = new HashMap<>(options.customParams);
            return built;
        }
    }

    // Getters
    public boolean isExactSearch() { return exactSearch; }
    public boolean isEnableCache() { return enableCache; }
    public boolean isDeduplication() { return deduplication; }
    public float getMaxDistance() { return maxDistance; }
    public List<Long> getExcludeIds() { return excludeIds; }
    public Map<String, Object> getCustomParams() { return customParams; }
}
/**
 * In-memory cache of search results keyed by query key.
 *
 * <p>Entries expire {@code ttlMs} milliseconds after insertion. Once {@code maxSize} entries are
 * held, expired entries are purged first and then the least-recently-accessed entry is evicted.
 * Callers always receive a copy of the cached list, so they cannot corrupt the cache.
 */
public class QueryCache {
    private final Map<String, CacheEntry> cache = new ConcurrentHashMap<>();
    private final int maxSize;
    private final long ttlMs;

    /** Creates a cache with a default TTL of 5 minutes. */
    public QueryCache(int maxSize) {
        this(maxSize, 300_000);
    }

    /**
     * @param maxSize maximum number of entries; must be positive
     * @param ttlMs   time-to-live of each entry, in milliseconds
     * @throws IllegalArgumentException if {@code maxSize} is not positive
     */
    public QueryCache(int maxSize, long ttlMs) {
        if (maxSize <= 0) {
            throw new IllegalArgumentException("maxSize must be positive: " + maxSize);
        }
        this.maxSize = maxSize;
        this.ttlMs = ttlMs;
    }

    /**
     * Returns a copy of the cached results for {@code key}, or null if absent or expired.
     */
    public List<SearchResult> get(String key) {
        CacheEntry entry = cache.get(key);
        if (entry == null) {
            return null;
        }
        if (entry.isExpired()) {
            // Conditional remove: don't drop a fresh entry another thread stored in the meantime.
            cache.remove(key, entry);
            return null;
        }
        entry.updateAccessTime();
        return new ArrayList<>(entry.getResults());
    }

    /**
     * Stores a copy of {@code results} under {@code key}, evicting if the cache is full.
     */
    public void put(String key, List<SearchResult> results) {
        // Overwriting an existing key does not grow the cache, so no eviction is needed then.
        if (!cache.containsKey(key) && cache.size() >= maxSize) {
            evictExpired();
            if (cache.size() >= maxSize) {
                evictLRU();
            }
        }
        cache.put(key, new CacheEntry(results, System.currentTimeMillis()));
    }

    /** Removes all entries whose TTL has elapsed. */
    private void evictExpired() {
        cache.values().removeIf(CacheEntry::isExpired);
    }

    /** Evicts the entry with the oldest last-access time (linear scan). */
    private void evictLRU() {
        String lruKey = null;
        long lruTime = Long.MAX_VALUE;
        for (Map.Entry<String, CacheEntry> entry : cache.entrySet()) {
            if (entry.getValue().getLastAccessTime() < lruTime) {
                lruTime = entry.getValue().getLastAccessTime();
                lruKey = entry.getKey();
            }
        }
        if (lruKey != null) {
            cache.remove(lruKey);
        }
    }

    /** A cached result list with creation and last-access timestamps (epoch millis). */
    private class CacheEntry {
        private final List<SearchResult> results;
        private final long createTime;
        private volatile long lastAccessTime;

        public CacheEntry(List<SearchResult> results, long createTime) {
            this.results = new ArrayList<>(results); // defensive copy of the caller's list
            this.createTime = createTime;
            this.lastAccessTime = createTime;
        }

        public boolean isExpired() {
            return System.currentTimeMillis() - createTime > ttlMs;
        }

        public void updateAccessTime() {
            this.lastAccessTime = System.currentTimeMillis();
        }

        public List<SearchResult> getResults() { return results; }
        public long getLastAccessTime() { return lastAccessTime; }
    }
}
/**
 * Replaces stored vectors through a pluggable {@link UpdateStrategy}, recording every attempt with
 * an {@code OperationAuditor}.
 */
public class VectorUpdateOperation {
    private final HnswIndex index;
    private final OperationAuditor auditor;
    private final UpdateStrategy strategy;

    public VectorUpdateOperation(HnswIndex index, UpdateStrategy strategy) {
        this.index = index;
        this.auditor = new OperationAuditor();
        this.strategy = strategy;
    }

    /**
     * Replaces the vector stored under {@code id} with {@code newVector}.
     *
     * @return true if the strategy applied the update; false if {@code id} is not present or the
     *         strategy reported failure
     * @throws VectorOperationException if {@code newVector} is null or the update throws
     */
    public boolean update(long id, Vector newVector) {
        long operationId = auditor.startOperation("UPDATE", id);
        long startTime = System.nanoTime();
        try {
            // Reject null up front: a delete-then-add strategy would otherwise remove the old vector
            // before failing to insert the replacement.
            if (newVector == null) {
                throw new IllegalArgumentException("newVector must not be null");
            }
            if (!index.contains(id)) {
                auditor.recordError(operationId, new NotFoundException("Vector not found: " + id));
                return false;
            }
            // The current vector is handed to the strategy (e.g. for rollback or change detection).
            Vector oldVector = index.getVector(id);
            boolean success = strategy.update(index, id, oldVector, newVector);
            if (success) {
                auditor.recordSuccess(operationId, System.nanoTime() - startTime);
                invalidateRelatedCaches(id);
            } else {
                auditor.recordError(operationId, new UpdateFailedException("Update failed for id: " + id));
            }
            return success;
        } catch (Exception e) {
            auditor.recordError(operationId, e);
            throw new VectorOperationException("Failed to update vector " + id, e);
        }
    }

    /**
     * Applies {@link #update} to every entry in parallel.
     *
     * <p>If any single update throws, the exception propagates out of this method; updates already
     * applied are not rolled back.
     *
     * @return per-ID result of {@link #update}
     */
    public Map<Long, Boolean> updateBatch(Map<Long, Vector> vectors) {
        Map<Long, Boolean> results = new ConcurrentHashMap<>();
        vectors.entrySet().parallelStream().forEach(entry -> {
            boolean success = update(entry.getKey(), entry.getValue());
            results.put(entry.getKey(), success);
        });
        return results;
    }

    /**
     * Placeholder for invalidating cached query results that may contain {@code id}; currently a
     * no-op, so cached searches can return stale vectors after an update.
     */
    private void invalidateRelatedCaches(long id) {
    }
}
/**
 * Pluggable policy for replacing the vector stored under an existing ID.
 */
@FunctionalInterface
public interface UpdateStrategy {
    /**
     * Replaces the vector stored under {@code id}.
     *
     * @param index     index holding the vector
     * @param id        ID of the vector being replaced
     * @param oldVector the vector currently stored under {@code id}, as read by the caller
     * @param newVector the replacement vector
     * @return {@code true} if the replacement was applied, {@code false} otherwise
     */
    boolean update(HnswIndex index, long id, Vector oldVector, Vector newVector);
}
/**
 * Update strategy that removes the old vector and inserts the new one under the same ID, both
 * inside a single index transaction.
 */
public class DeleteAndAddUpdateStrategy implements UpdateStrategy {
    private static final Logger logger = LoggerFactory.getLogger(DeleteAndAddUpdateStrategy.class);

    /**
     * @return true once the transaction commits; false if the ID could not be removed or any
     *         exception occurred (the exception is logged).
     *         NOTE(review): relies on closing an uncommitted {@code Transaction} rolling it back —
     *         confirm.
     */
    @Override
    public boolean update(HnswIndex index, long id, Vector oldVector, Vector newVector) {
        try (Transaction tx = index.beginTransaction()) {
            // 1. Remove the current vector
            boolean deleted = index.remove(id);
            if (!deleted) {
                tx.rollback();
                return false;
            }
            // 2. Insert the replacement under the same ID
            index.add(id, newVector);
            tx.commit();
            return true;
        } catch (Exception e) {
            logger.error("Delete-and-add update failed for id: {}", id, e);
            return false;
        }
    }
}
/**
 * Update strategy that overwrites the vector held by the existing graph node.
 *
 * <p>When the new vector differs enough from the old one (cosine similarity below a threshold),
 * the node is marked for reconnection.
 */
public class InPlaceUpdateStrategy implements UpdateStrategy {
    private static final Logger logger = LoggerFactory.getLogger(InPlaceUpdateStrategy.class);
    /** Default cosine-similarity threshold below which the node is reconnected. */
    public static final double DEFAULT_RECONNECT_THRESHOLD = 0.8;

    private final double reconnectThreshold;

    /** Uses {@link #DEFAULT_RECONNECT_THRESHOLD}. */
    public InPlaceUpdateStrategy() {
        this(DEFAULT_RECONNECT_THRESHOLD);
    }

    /**
     * @param reconnectThreshold reconnect the node when the old/new cosine similarity falls below
     *                           this value
     */
    public InPlaceUpdateStrategy(double reconnectThreshold) {
        this.reconnectThreshold = reconnectThreshold;
    }

    /**
     * @return true if the node's vector was replaced; false if the node does not exist or an
     *         exception occurred (logged)
     */
    @Override
    public boolean update(HnswIndex index, long id, Vector oldVector, Vector newVector) {
        try {
            HnswNode node = index.getNode(id);
            if (node == null) {
                return false;
            }
            // Decide before mutating: if the similarity computation throws, the node is left
            // untouched instead of being updated while the method reports failure.
            boolean reconnect = shouldReconnect(oldVector, newVector);
            node.updateVector(newVector);
            if (reconnect) {
                reconnectNode(index, node);
            }
            return true;
        } catch (Exception e) {
            logger.error("In-place update failed for id: {}", id, e);
            return false;
        }
    }

    /** Returns whether the vector changed enough to warrant rebuilding the node's links. */
    private boolean shouldReconnect(Vector oldVector, Vector newVector) {
        float similarity = cosineSimilarity(oldVector, newVector);
        // Written as !(>=) rather than (<) so that a NaN similarity (undefined, e.g. for a
        // zero-length vector) also triggers reconnection.
        return !(similarity >= reconnectThreshold);
    }

    /**
     * Placeholder for recomputing the node's neighbour links; currently a no-op, so after a large
     * change the graph keeps the node's old links.
     */
    private void reconnectNode(HnswIndex index, HnswNode node) {
    }
}
/**
 * Removes stored vectors through a pluggable {@link DeletionStrategy}, recording every attempt
 * with an {@code OperationAuditor}.
 */
public class VectorDeleteOperation {
    private final HnswIndex index;
    private final OperationAuditor auditor;
    private final DeletionStrategy strategy;

    public VectorDeleteOperation(HnswIndex index, DeletionStrategy strategy) {
        this.index = index;
        this.auditor = new OperationAuditor();
        this.strategy = strategy;
    }

    /**
     * Removes the vector stored under {@code id}.
     *
     * @return true if the strategy reports the deletion succeeded; false if {@code id} is not
     *         present or the strategy reported failure
     * @throws VectorOperationException if the deletion throws
     */
    public boolean delete(long id) {
        long operationId = auditor.startOperation("DELETE", id);
        long startTime = System.nanoTime();
        try {
            if (!index.contains(id)) {
                auditor.recordError(operationId, new NotFoundException("Vector not found: " + id));
                return false;
            }
            boolean success = strategy.delete(index, id);
            if (success) {
                auditor.recordSuccess(operationId, System.nanoTime() - startTime);
                cleanupAfterDeletion(id);
            } else {
                // Close the audit record for a failed strategy instead of leaving it open.
                auditor.recordError(operationId, new IllegalStateException("Delete failed for id: " + id));
            }
            return success;
        } catch (Exception e) {
            auditor.recordError(operationId, e);
            throw new VectorOperationException("Failed to delete vector " + id, e);
        }
    }

    /**
     * Deletes every distinct ID in {@code ids}, in parallel.
     *
     * <p>If any single delete throws, the exception propagates out of this method.
     *
     * @return per-ID result; IDs not present in the index map to false
     */
    public Map<Long, Boolean> deleteBatch(List<Long> ids) {
        Map<Long, Boolean> results = new ConcurrentHashMap<>();
        // De-duplicate first: two parallel deletes of the same ID would race, and the loser's
        // `false` could overwrite the winner's `true`. The set also makes lookups O(1).
        Set<Long> uniqueIds = new HashSet<>(ids);
        List<Long> existingIds = uniqueIds.stream()
                .filter(index::contains)
                .collect(Collectors.toList());
        existingIds.parallelStream().forEach(id -> results.put(id, delete(id)));
        // Every ID not handled above was missing from the index.
        for (Long id : uniqueIds) {
            results.putIfAbsent(id, false);
        }
        return results;
    }

    /**
     * Placeholder for post-deletion cleanup (query cache, statistics); currently a no-op.
     */
    private void cleanupAfterDeletion(long id) {
    }
}
/**
 * Pluggable policy for removing a vector from an index.
 */
@FunctionalInterface
public interface DeletionStrategy {
    /**
     * Removes (or schedules removal of) the vector stored under {@code id}.
     *
     * @param index index holding the vector
     * @param id    ID of the vector to remove
     * @return {@code true} if the deletion was accepted, {@code false} otherwise
     */
    boolean delete(HnswIndex index, long id);
}
/**
 * Deletion strategy that removes the vector from the index synchronously.
 */
public class ImmediateDeletionStrategy implements DeletionStrategy {
    private static final Logger logger = LoggerFactory.getLogger(ImmediateDeletionStrategy.class);

    /**
     * @return the result of {@code index.remove(id)}, or false if it threw (the exception is logged)
     */
    @Override
    public boolean delete(HnswIndex index, long id) {
        try {
            return index.remove(id);
        } catch (Exception e) {
            logger.error("Immediate deletion failed for id: {}", id, e);
            return false;
        }
    }
}
/**
 * Deletion strategy that only marks IDs as deleted and removes them from their index later.
 *
 * <p>Removal happens on a background daemon thread every 60 seconds. A mark is cleared only after
 * the removal succeeds; failed removals are retried on the next run. Call {@link #close()} to stop
 * the background thread. Marks that are still pending at that point are not flushed.
 */
public class LazyDeletionStrategy implements DeletionStrategy, AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(LazyDeletionStrategy.class);

    /** IDs marked as deleted, mapped to the index they must eventually be removed from. */
    private final Map<Long, HnswIndex> pendingDeletions = new ConcurrentHashMap<>();
    private final ScheduledExecutorService scheduler =
            Executors.newSingleThreadScheduledExecutor(task -> {
                Thread thread = new Thread(task, "lazy-deletion");
                // Daemon, so a strategy that is never closed does not keep the JVM alive.
                thread.setDaemon(true);
                return thread;
            });

    public LazyDeletionStrategy() {
        scheduler.scheduleWithFixedDelay(this::performActualDeletion, 60, 60, TimeUnit.SECONDS);
    }

    /**
     * Marks {@code id} for later removal from {@code index}; always returns true.
     */
    @Override
    public boolean delete(HnswIndex index, long id) {
        pendingDeletions.put(id, index);
        return true;
    }

    /** Removes every marked ID from its index; runs on the scheduler thread. */
    private void performActualDeletion() {
        // ConcurrentHashMap iteration is weakly consistent: IDs marked during this pass are either
        // handled now or on the next run, never lost.
        for (Map.Entry<Long, HnswIndex> entry : pendingDeletions.entrySet()) {
            Long id = entry.getKey();
            HnswIndex index = entry.getValue();
            try {
                index.remove(id);
                // Unmark only after removal, and only if the ID was not re-marked for another index.
                pendingDeletions.remove(id, index);
            } catch (Exception e) {
                // Keep the mark so the next run retries this ID.
                logger.warn("Failed to delete vector {}", id, e);
            }
        }
    }

    /**
     * Returns whether {@code id} is marked as deleted but not yet removed from its index.
     */
    public boolean isDeleted(long id) {
        return pendingDeletions.containsKey(id);
    }

    /** Stops the background deletion thread. */
    @Override
    public void close() {
        scheduler.shutdown();
    }
}
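本章详细介绍了向量数据库的完整 CRUD 操作实现:创建(单个与批量添加,支持自动生成 ID 或指定 ID)、读取(按 ID 查询、批量查询、K 近邻搜索、范围搜索以及查询结果缓存)、更新(“删除后重新添加”与“原地更新”两种策略)和删除(立即删除与延迟标记删除两种策略)。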
这些CRUD操作的实现为向量数据库提供了完整的数据管理能力,支持高并发、高性能的向量数据操作。
思考题: