Files
03Rag/rag/vectorstore.py
T
heyong.fu a17c65c4bc feat: rag
2026-05-06 11:35:10 +08:00

70 lines
1.9 KiB
Python

import chromadb
from typing import Optional
import logging
import os
import hashlib
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
# Default Chroma collection name.
DEFAULT_COLLECTION_NAME = "rag"
# Default SentenceTransformer model used to embed text.
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# Default on-disk location of the persistent Chroma database.
# NOTE: a relative path like this resolves against the process's current
# working directory, not against this file's directory.
DEFAULT_DB_PATH = "./chroma_db"
# Lazily-initialised embedding model shared by this module (set by _get_mode).
_model: Optional[SentenceTransformer] = None
# Lazily-initialised Chroma client shared by this module (set by _get_client).
# chromadb.PersistentClient is a factory function rather than a class, so the
# annotation names the client interface it returns instead. It is written as a
# string so it is never evaluated at import time.
_client: Optional["chromadb.api.ClientAPI"] = None
def _get_mode():
    """Return the module-wide SentenceTransformer, loading it on first use.

    The model named by DEFAULT_MODEL_NAME is built once and cached in the
    module-level ``_model``; later calls hand back that same instance.
    (The name presumably means "get model".)
    """
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(DEFAULT_MODEL_NAME)
    return _model
def _get_client():
    """Return the module-wide Chroma persistent client, creating it on first use.

    The client is opened at DEFAULT_DB_PATH once and cached in the
    module-level ``_client``; later calls hand back that same instance.
    """
    global _client
    if _client is not None:
        return _client
    _client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
    return _client
def save_text_to_db(
    text: str,
    collection_name: str = DEFAULT_COLLECTION_NAME,
    *,
    metadata: Optional[dict] = None,
) -> str:
    """Embed ``text`` and store it in a Chroma collection, skipping duplicates.

    The record id is the MD5 hex digest of the UTF-8 encoded text, so saving
    the same text twice stores it only once.

    Args:
        text: Text to embed and store. Empty or whitespace-only text is skipped.
        collection_name: Collection to write to; created if it does not exist.
        metadata: Optional extra metadata for the record. It is merged over the
            default ``{"source": "document"}``, so callers may add keys or
            override ``source``. When the text is already stored, the existing
            record is left untouched and this argument is ignored.

    Returns:
        The record id, or ``""`` when ``text`` is empty or whitespace-only.

    Raises:
        Exception: Any error from the embedding model or Chroma is logged with
            its traceback and re-raised unchanged.
    """
    try:
        if not text or not text.strip():
            # "Empty text, skipped."
            logger.warning("空文本,已跳过")
            return ""
        model = _get_mode()
        client = _get_client()
        collection = client.get_or_create_collection(collection_name)
        # Content-addressed id: identical text always maps to the same id,
        # which is what makes the duplicate check below work. MD5 serves as a
        # fingerprint here, not as a security measure.
        text_id = hashlib.md5(text.encode("utf-8")).hexdigest()
        existing = collection.get(ids=[text_id])
        if existing and existing.get("ids"):
            # "This text was already saved, skipping, id=..."
            logger.info("此文本已保存过,跳过保存,id=%s", text_id)
            return text_id
        # encode() yields an ndarray (one row per input); Chroma wants lists.
        embedding = model.encode([text])[0].tolist()
        record_metadata = {"source": "document", **(metadata or {})}
        collection.add(
            documents=[text],
            embeddings=[embedding],
            ids=[text_id],
            metadatas=[record_metadata],
        )
        return text_id
    except Exception as e:
        # "Failed to save text to the vector store". logger.exception keeps
        # the traceback that the previous logger.error(str(e)) discarded.
        logger.exception("保存文本向量库失败%s", e)
        raise