70 lines
1.9 KiB
Python
70 lines
1.9 KiB
Python
import chromadb
|
|
from typing import Optional
|
|
import logging
|
|
import os
|
|
import hashlib
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default name of the Chroma collection that documents are written to.
DEFAULT_COLLECTION_NAME = "rag"

# Default name of the sentence-embedding model.
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"

# Default on-disk location of the persistent Chroma database.
DEFAULT_DB_PATH = "./chroma_db"

# Process-wide embedding model; stays None until it is first requested.
_model: Optional[SentenceTransformer] = None

# Process-wide database client; stays None until it is first requested.
# NOTE(review): PersistentClient may be a factory function rather than a class
# in current chromadb releases -- confirm this annotation names a real type.
_client: Optional[chromadb.PersistentClient] = None
|
|
|
|
|
|
def _get_mode():
    """Return the shared embedding model, constructing it on first use.

    The instance is cached in the module-level ``_model`` so that later calls
    reuse the same object instead of building a new one.
    """
    global _model

    # Fast path: the model was already built by an earlier call.
    if _model is not None:
        return _model

    _model = SentenceTransformer(DEFAULT_MODEL_NAME)
    return _model
|
|
|
|
|
|
def _get_client():
    """Return the shared Chroma client, creating it on first use.

    The client is created with ``DEFAULT_DB_PATH`` as its storage path and is
    cached in the module-level ``_client`` for reuse by later calls.
    """
    global _client

    # Fast path: the client was already created by an earlier call.
    if _client is not None:
        return _client

    _client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
    return _client
|
|
|
|
|
|
def save_text_to_db(text: str, collection_name: str = DEFAULT_COLLECTION_NAME) -> str:
    """Embed *text* and store it in a Chroma collection, skipping duplicates.

    The document id is the MD5 hex digest of the UTF-8 encoded text, so a text
    that was already stored is detected by id and is not embedded again.

    Args:
        text: Text to store. Empty or whitespace-only text is skipped.
        collection_name: Collection to write to; it is fetched or created via
            ``get_or_create_collection``.

    Returns:
        The document id, or ``""`` when the text was skipped as empty.

    Raises:
        Exception: Any error raised while embedding or storing is logged
            together with its traceback and then re-raised unchanged.
    """
    try:
        if not text or not text.strip():
            logger.warning("空文本,已跳过")
            return ""

        # Obtain the shared model first, then the shared client.
        model = _get_mode()
        client = _get_client()
        collection = client.get_or_create_collection(collection_name)

        # MD5 is used here only as a content fingerprint for the document id.
        text_id = hashlib.md5(text.encode("utf-8")).hexdigest()

        # Skip the embedding work when this exact text is already stored.
        existing = collection.get(ids=[text_id])
        if existing and existing.get("ids"):
            logger.info("此文本已保存过,跳过保存,id=%s", text_id)
            return text_id

        # Encode a one-item batch, take its only row and convert it to a list
        # before handing it to the collection.
        embedding = model.encode([text])[0].tolist()

        collection.add(
            documents=[text],
            embeddings=[embedding],
            ids=[text_id],
            metadatas=[{"source": "document"}],
        )
        return text_id
    except Exception as e:
        # logger.exception logs at ERROR level like before, but also attaches
        # the traceback to the record; the rendered message text is unchanged.
        logger.exception("保存文本向量库失败%s", e)
        raise
|