feat: rag
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
import chromadb
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default collection name, used when the caller does not pass one.
|
||||
DEFAULT_COLLECTION_NAME = "rag"
|
||||
# Default embedding (vectorization) model name.
|
||||
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
|
||||
# Default on-disk location of the database.
|
||||
DEFAULT_DB_PATH = "./chroma_db"
|
||||
|
||||
# Global embedding model; stays None until _get_mode() creates it on first use.
|
||||
_model: Optional[SentenceTransformer] = None
|
||||
# Global database client; stays None until _get_client() creates it on first use.
# NOTE(review): chromadb.PersistentClient looks like a factory function rather
# than a class in current chromadb releases — confirm whether the client
# interface type would be the more accurate annotation here.
|
||||
_client: Optional[chromadb.PersistentClient] = None
|
||||
|
||||
|
||||
def _get_mode():
    """Return the shared embedding model, creating it on the first call.

    The instance is cached in the module-level ``_model`` variable, so
    ``SentenceTransformer(DEFAULT_MODEL_NAME)`` is constructed at most once
    per process and reused by every later call.
    """
    global _model
    # Fast path: the model has already been created.
    if _model is not None:
        return _model
    _model = SentenceTransformer(DEFAULT_MODEL_NAME)
    return _model
|
||||
|
||||
|
||||
def _get_client():
    """Return the shared Chroma client, creating it on the first call.

    The client is cached in the module-level ``_client`` variable and is
    built with ``chromadb.PersistentClient(path=DEFAULT_DB_PATH)`` the first
    time it is requested.
    """
    global _client
    # Fast path: the client has already been created.
    if _client is not None:
        return _client
    _client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
    return _client
|
||||
|
||||
|
||||
def save_text_to_db(text: str, collection_name: str = DEFAULT_COLLECTION_NAME) -> str:
    """Embed ``text`` and store it in a Chroma collection, skipping duplicates.

    The record ID is the MD5 hex digest of the UTF-8 encoded text, so saving
    the same text again resolves to the same ID and nothing new is written.

    Args:
        text: The text to store.
        collection_name: Name passed to ``get_or_create_collection``.

    Returns:
        The ID of the stored (or already present) record, or ``""`` when
        ``text`` is empty or whitespace-only and nothing was saved.

    Raises:
        Exception: Any error raised while obtaining the model/client,
            querying the collection, encoding the text or adding the record
            is logged with its traceback and re-raised unchanged.
    """
    # Guard clause sits outside the ``try``: skipping blank input is not a
    # save failure and should not be reported as one.
    if not text or not text.strip():
        logger.warning("空文本,已跳过")
        return ""

    try:
        model = _get_mode()
        client = _get_client()
        collection = client.get_or_create_collection(collection_name)

        # Content-derived ID (an identifier, not a security measure).
        text_id = hashlib.md5(text.encode("utf-8")).hexdigest()

        # Look the ID up before encoding so duplicates skip the embedding work.
        existing = collection.get(ids=[text_id])
        if existing and existing.get("ids"):
            logger.info("此文本已保存过,跳过保存,id=%s", text_id)
            return text_id

        # encode() receives a one-element batch; take its single row and
        # convert it to a plain list before handing it to the database.
        embedding = model.encode([text])[0].tolist()

        collection.add(
            documents=[text],
            embeddings=[embedding],
            ids=[text_id],
            metadatas=[{"source": "document"}],
        )
        return text_id
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which the previous logger.error(f"...") call discarded.
        logger.exception("保存文本向量库失败%s", e)
        raise
|
||||
Reference in New Issue
Block a user