"""Minimal RAG query pipeline.

Embeds a user question with a SentenceTransformer model, retrieves the most
similar document chunks from a persistent ChromaDB collection, and asks the
LLM to answer the question based on those chunks.
"""

import logging
import os
from typing import Dict, List, Optional

import chromadb
from sentence_transformers import SentenceTransformer

from llm import get_doubao_llm

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

# Name of the collection used when the caller does not specify one.
DEFAULT_COLLECTION_NAME = "rag_system"
# Number of chunks returned by a retrieval by default.
DEFAULT_N_RESULTS = 2
# Name of the default embedding model.
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
# Default on-disk location of the ChromaDB database.
DEFAULT_DB_PATH = "./chroma_db"

# Lazily created embedding model (see _get_model).
_mode: Optional[SentenceTransformer] = None
# Lazily created ChromaDB client (see _get_client).
_client: Optional[chromadb.PersistentClient] = None
# Collections already opened, keyed by collection name. A per-name cache is
# required: a single cached collection would be returned for every name.
_collections: Dict[str, chromadb.Collection] = {}


def _get_model() -> SentenceTransformer:
    """Return the shared embedding model, loading it on first use."""
    global _mode
    if _mode is None:
        _mode = SentenceTransformer(DEFAULT_MODEL_NAME)
    return _mode


def _get_client():
    """Return the shared ChromaDB client, creating it on first use."""
    global _client
    if _client is None:
        _client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
    return _client


def get_query_embedding(query: str) -> List[float]:
    """Encode ``query`` with the embedding model and return it as a list."""
    model = _get_model()
    # encode() receives a one-element batch; take the single resulting row.
    return model.encode([query])[0].tolist()


def _get_collection(collection_name: str = DEFAULT_COLLECTION_NAME):
    """Return the collection called ``collection_name``.

    The collection is fetched (or created) through the shared client the
    first time a given name is requested and cached per name afterwards.
    """
    collection = _collections.get(collection_name)
    if collection is None:
        collection = _get_client().get_or_create_collection(collection_name)
        _collections[collection_name] = collection
    return collection


def retrieve_relate_chunks(
    query_embedding: List[float],
    n_results: int = DEFAULT_N_RESULTS,
    collection_name: str = DEFAULT_COLLECTION_NAME,
):
    """Return the documents most similar to ``query_embedding``.

    Args:
        query_embedding: Embedding vector of the user query.
        n_results: Number of documents to request from the collection.
        collection_name: Name of the collection to search.

    Returns:
        The documents returned for the (single) query embedding.

    Raises:
        ValueError: If the search returns no documents.
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        collection = _get_collection(collection_name)
        # Similarity search in the requested collection.
        results = collection.query(
            query_embeddings=[query_embedding], n_results=n_results
        )
        related_chunks = results.get("documents")
        # "documents" holds one list per query embedding; only one was sent.
        if not related_chunks or not related_chunks[0]:
            raise ValueError("未找到相关内容")
        return related_chunks[0]
    except Exception as e:
        logger.error("向量检索失败:%s", e)
        raise


def query_rag(
    query: str,
    n_results: int = DEFAULT_N_RESULTS,
    collection_name: str = DEFAULT_COLLECTION_NAME,
):
    """Answer ``query`` using retrieved context and the LLM.

    Args:
        query: The user's question.
        n_results: Number of chunks to retrieve as context.
        collection_name: Name of the collection to search.

    Returns:
        Whatever ``get_doubao_llm`` returns for the assembled prompt.

    Raises:
        ValueError: If no related content is found for the query.
    """
    # 1. Turn the question into an embedding vector.
    query_embedding = get_query_embedding(query)

    # 2. Retrieve the most similar chunks for that vector.
    related_chunks = retrieve_relate_chunks(
        query_embedding, n_results, collection_name=collection_name
    )

    # 3. Build the prompt from the retrieved context and ask the LLM.
    content = "\n".join(related_chunks)
    prompt = f"""
    已知信息:{content}
    请根据上述内容回答用户问题:{query}
    """
    logger.info("RAG prompt:%s", prompt)
    answer = get_doubao_llm(prompt)
    return answer


if __name__ == "__main__":
    # Demo query; guarded so that importing this module has no side effects
    # such as loading the model, opening the database or calling the LLM.
    query = "西游记是谁写的"
    try:
        answer = query_rag(query, n_results=1)
        print("答案:", answer)
    except Exception as e:
        print(f"错误{e}")