feat: rag
This commit is contained in:
@@ -0,0 +1,2 @@
|
|||||||
|
DATABASE_URL=https://ark.cn-beijing.volces.com/api/v3
|
||||||
|
API_KEY=79b39c58-56db-4d8a-a8f8-84b95fca08db
|
||||||
+10
@@ -0,0 +1,10 @@
|
|||||||
|
# Python-generated files
|
||||||
|
__pycache__/
|
||||||
|
*.py[oc]
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
wheels/
|
||||||
|
*.egg-info
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
# 读取pdf文件
|
||||||
|
import fitz
|
||||||
|
|
||||||
|
|
||||||
|
def extract_pfd_text(pdf_path):
|
||||||
|
"""
|
||||||
|
提取pdf文件中的内容
|
||||||
|
参数:pdf_path(str):pdf文件路径
|
||||||
|
返回:
|
||||||
|
str:合并后所有页的文本
|
||||||
|
"""
|
||||||
|
# 打开pdf文件
|
||||||
|
pdf = fitz.open(pdf_path)
|
||||||
|
# 存储每一页的信息
|
||||||
|
text_list = []
|
||||||
|
# 遍历pdf中的每一页
|
||||||
|
for page in pdf:
|
||||||
|
text_list.append(page.get_text("text"))
|
||||||
|
# 所有内容合并成一个字符串
|
||||||
|
all_text = "/n".join(text_list)
|
||||||
|
return all_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pdf_path = "example/example.pdf"
|
||||||
|
result_text = extract_pfd_text(pdf_path)
|
||||||
|
print(result_text)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# 读取word文件
|
||||||
|
from docx import Document
|
||||||
|
|
||||||
|
|
||||||
|
# 定义函数
|
||||||
|
def extract_text_from_word(file_path):
|
||||||
|
"""
|
||||||
|
从word文档中提取所有段落,并以字符串返回
|
||||||
|
param file_path:文件地址
|
||||||
|
return: 返回文本内容字符串
|
||||||
|
|
||||||
|
"""
|
||||||
|
# 加载文件
|
||||||
|
doc = Document(file_path)
|
||||||
|
text = "\n".join([para.text for para in doc.paragraphs])
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.docx"
|
||||||
|
result = extract_text_from_word(file_path)
|
||||||
|
print(result)
|
||||||
+28
@@ -0,0 +1,28 @@
|
|||||||
|
import openpyxl
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_from_excel(file_path):
|
||||||
|
"""
|
||||||
|
从Excel文件中提取所有单元格内容为文本,并以字符串返回。
|
||||||
|
:param file_path: Excel文件路径
|
||||||
|
:return: 文本内容字符串
|
||||||
|
"""
|
||||||
|
# 加载Excel表格
|
||||||
|
wb = openpyxl.load_workbook(file_path)
|
||||||
|
# 获取活动的工作表小
|
||||||
|
ws = wb.active
|
||||||
|
# 初始化用于存储每一行的文本列表
|
||||||
|
rows = []
|
||||||
|
# 遍历工作区的每一行,values_only = True 标识只获取单元格的值
|
||||||
|
for row in ws.iter_rows(values_only=True):
|
||||||
|
# 将每一行的单元格的数据转为字符串,并用制表符分割,如果为空返回空字符串
|
||||||
|
rows.append("\t".join([str(cell) if cell is not None else "" for cell in row]))
|
||||||
|
all_text = "\n".join(rows)
|
||||||
|
|
||||||
|
return all_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.xlsx"
|
||||||
|
result = extract_text_from_excel(file_path)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
# 读取ppt文件
|
||||||
|
from pptx import Presentation
|
||||||
|
|
||||||
|
# 定义函数,提取ppt文件中的所有文本内容
|
||||||
|
|
||||||
|
|
||||||
|
def extract_ppt_text(file_path):
|
||||||
|
"""
|
||||||
|
提取PPT文件中的所有文本内容,并以字符串返回。
|
||||||
|
:param file_path: PPT文件路径
|
||||||
|
:return: 所有文本内容(以换行符分隔)
|
||||||
|
"""
|
||||||
|
# 加载ppt文件
|
||||||
|
ppt = Presentation(file_path)
|
||||||
|
# 初始化用于存储ppt文本的列表
|
||||||
|
text_list = []
|
||||||
|
# 遍历PPT中的每一页幻灯片
|
||||||
|
for slide in ppt.slides:
|
||||||
|
# 遍历幻灯片中的每一个形状
|
||||||
|
for shape in slide.shapes:
|
||||||
|
# 判断该形状是否有text属性(即是否包含文本)
|
||||||
|
if hasattr(shape, "text"):
|
||||||
|
# 如果有文本添加到text_list中
|
||||||
|
text_list.append(shape.text)
|
||||||
|
|
||||||
|
all_text = "\n".join(text_list)
|
||||||
|
return all_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.pptx"
|
||||||
|
result = extract_ppt_text(file_path)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
# 读取html文件
|
||||||
|
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_html(file_path):
|
||||||
|
"""
|
||||||
|
从指定HTML文件中提取所有文本内容
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): HTML文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 提取的文本内容
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 读取整个html文件内容字符串
|
||||||
|
html = f.read()
|
||||||
|
# 使用BeautifulSoup解析html内容
|
||||||
|
soup = BeautifulSoup(html, "html.parser")
|
||||||
|
# 提取所有文本内容,使用换行符分割
|
||||||
|
text = soup.get_text(separator="\n")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.html"
|
||||||
|
result = extract_text_html(file_path)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# 读取json文件
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(file_path):
|
||||||
|
"""
|
||||||
|
读取指定JSON文件并以格式化字符串打印内容
|
||||||
|
:param file_path: JSON文件路径
|
||||||
|
"""
|
||||||
|
# 以utf-8的格式打开指定的json文件
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 使用json.load读取文件内容为python对象
|
||||||
|
data = json.load(f)
|
||||||
|
# 使用json.dumps将python对象格式化为带有缩进的字符串,确保中文正常显示
|
||||||
|
text = json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.json"
|
||||||
|
result = read_json(file_path)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
# 读取xml文件格式
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
|
||||||
|
|
||||||
|
def extract_xml_text(file_path):
|
||||||
|
"""
|
||||||
|
读取XML文件并提取所有文本内容
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): XML文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 提取的所有文本内容
|
||||||
|
"""
|
||||||
|
# 以utf-8格式打开文件
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 读取xml文件的全部字符串
|
||||||
|
xml = f.read()
|
||||||
|
# 将字符串形式的xml内容解析为xms树结构
|
||||||
|
root = etree.fromstring(xml.encode("utf-8"))
|
||||||
|
# 遍历xml树,提取所有文本内容,并用空格链接
|
||||||
|
text = " ".join(root.itertext())
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
file_path = "example/example.xml"
|
||||||
|
result = extract_xml_text(file_path)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# 去读csv文件
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
def read_csv_to_text(filename):
|
||||||
|
"""
|
||||||
|
读取CSV文件内容,并将每行用逗号连接,所有行用换行符拼接成一个字符串返回。
|
||||||
|
"""
|
||||||
|
# 以uft-8格式打开文件
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
# 创建csv.reader对象,按行读取csv内容
|
||||||
|
reader = csv.reader(f)
|
||||||
|
# 对每一行,用逗号链接各列,生成字符串列表
|
||||||
|
rows = [", ".join(row) for row in reader]
|
||||||
|
|
||||||
|
# 拼接所有文本
|
||||||
|
all_text = "\n".join(rows)
|
||||||
|
return all_text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
filename = "example/example.csv"
|
||||||
|
result = read_csv_to_text(filename)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# 读取纯文本内容
|
||||||
|
|
||||||
|
|
||||||
|
def read_text_file(filename):
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
filename = "example/example.txt"
|
||||||
|
result = read_text_file(filename)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# 读取纯文本内容
|
||||||
|
|
||||||
|
|
||||||
|
def read_markdown_file(filename):
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
text = f.read()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
filename = "example/example.md"
|
||||||
|
result = read_markdown_file(filename)
|
||||||
|
print(result)
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
from langchain_text_splitters import CharacterTextSplitter
|
||||||
|
|
||||||
|
|
||||||
|
# 创建字符分割器实例,设置每个块最大长度为100个字符,不重叠,使用空字符串进行分割
|
||||||
|
text_splitters = CharacterTextSplitter(
|
||||||
|
chunk_size=100, # 每个块的最大长度是100个字符
|
||||||
|
chunk_overlap=0, # 块之间不重叠
|
||||||
|
separator="", # 使用空白字符串作为分隔符
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建一个长文本
|
||||||
|
document = f"""{"1"*100}{"2"*100}{"3"*100}"""
|
||||||
|
|
||||||
|
# 使用分割器split_text方法,将原始文本切割成若干个字块
|
||||||
|
texts = text_splitters.split_text(document)
|
||||||
|
|
||||||
|
# 打印原始文本长度
|
||||||
|
print(f"原文长度{len(document)}")
|
||||||
|
# 打印分割后的块的数量
|
||||||
|
print(f"分割为{texts}个块")
|
||||||
|
|
||||||
|
for i, text in enumerate(texts, 1):
|
||||||
|
print(f"\n块{i}({len(text)}字符):{repr(text)}")
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# RecursiveCharacterTextSplitter 是 LangChain 中最常用的文本分割器,它实现了基于文本结构的分割策略。对于大多数应用场景,这是推荐的默认选择。
|
||||||
|
|
||||||
|
# 为什么推荐?
|
||||||
|
|
||||||
|
# 在保持上下文完整性和管理块大小之间取得了良好的平衡
|
||||||
|
# 开箱即用,默认配置就能很好地工作
|
||||||
|
# 只有在需要针对特定应用进行微调时才需要调整参数
|
||||||
|
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
# 创建递归文本分割器对象,指定参数
|
||||||
|
# chunk_size 表示每块最大允许的字符数100
|
||||||
|
# chunk_overlap 表示块与块之间没有重叠(重叠字符数0)
|
||||||
|
text_splitters = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=0)
|
||||||
|
|
||||||
|
|
||||||
|
document = f"""{"1"*100}\n{"2"*99}\n\n{"3"*99}\n{"4"*99}"""
|
||||||
|
|
||||||
|
# 使用文本分割器的split_text 方法将document进行分割成多个字符串的块
|
||||||
|
texts = text_splitters.split_text(document)
|
||||||
|
print(f"共分割出{len(texts)}个块")
|
||||||
|
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
print(f"\n块{i}({len(text)}字符):{repr(text)}")
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# 说明:验证 Sentence Transformers 是否安装成功
|
||||||
|
|
||||||
|
# 说明:导入 SentenceTransformer 类
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# 说明:尝试加载一个轻量级模型进行测试
|
||||||
|
# "all-MiniLM-L6-v2" 是一个小型的通用模型,适合快速测试
|
||||||
|
# 首次运行时会自动下载模型(可能需要一些时间)
|
||||||
|
print("正在加载模型进行测试...")
|
||||||
|
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
# 说明:对一个简单的句子进行编码测试
|
||||||
|
sentence = "这是一个测试句子"
|
||||||
|
embedding = model.encode(sentence)
|
||||||
|
|
||||||
|
# 说明:检查嵌入向量的形状
|
||||||
|
print(f"安装成功!嵌入向量维度:{embedding.shape}")
|
||||||
|
print(f"前 5 个维度值:{embedding[:5]}")
|
||||||
|
|
||||||
|
# 说明:如果没有报错并输出了维度信息,说明安装成功
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
# 正则表达式
|
||||||
|
import re
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 嵌入模型
|
||||||
|
model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
|
||||||
|
|
||||||
|
class Semantic_splitter:
|
||||||
|
def __init__(self, window_size, threshold):
|
||||||
|
# 设置每个窗口的大小
|
||||||
|
self.window_size = window_size
|
||||||
|
# 设置相邻窗口的相似度的阈值
|
||||||
|
self.threshold = threshold
|
||||||
|
|
||||||
|
def create_documents(self, text):
|
||||||
|
# 使用正则表达式对文本进行分割
|
||||||
|
sentences = re.split(r"(。|!|?|\!|\?|\.|\n)", text)
|
||||||
|
# print(sentences)
|
||||||
|
# 初始化句子列表
|
||||||
|
sents = []
|
||||||
|
for i in range(0, len(sentences) - 1, 2):
|
||||||
|
s = sentences[i].strip() + sentences[i + 1].strip()
|
||||||
|
if s.strip():
|
||||||
|
sents.append(s)
|
||||||
|
|
||||||
|
print(sents)
|
||||||
|
|
||||||
|
# 开始使用滑动窗口依据window_size进行切块
|
||||||
|
# 初始化分块列表
|
||||||
|
docs = []
|
||||||
|
# 起始位置
|
||||||
|
start = 0
|
||||||
|
while start < len(sents):
|
||||||
|
# 结束位置
|
||||||
|
end = min(start + self.window_size, len(sents))
|
||||||
|
window = sents[start:end]
|
||||||
|
docs.append("".join(window))
|
||||||
|
start = end
|
||||||
|
print(f"{len(docs)}个块") # 分为了3块,1句和2句,3句和4句,5句
|
||||||
|
print(docs)
|
||||||
|
# 计算每个块的向量值
|
||||||
|
embeddings = model.encode(docs)
|
||||||
|
# 初始化分割点列表,起点为0
|
||||||
|
split_points = [0]
|
||||||
|
print(embeddings)
|
||||||
|
# 计算余弦相似度
|
||||||
|
for i in range(1, len(docs)):
|
||||||
|
sim = np.dot(
|
||||||
|
embeddings[i - 1],
|
||||||
|
embeddings[i]
|
||||||
|
/ (np.linalg.norm(embeddings[i - 1]) * np.linalg.norm(embeddings[i])),
|
||||||
|
)
|
||||||
|
if sim < self.threshold:
|
||||||
|
print(f"相似度低于阈值{self.threshold},在位置{i}添加分割点")
|
||||||
|
split_points.append(i)
|
||||||
|
print(split_points)
|
||||||
|
|
||||||
|
# 初始化最终分块列表
|
||||||
|
result = []
|
||||||
|
for i in range(len(split_points)):
|
||||||
|
start = split_points[i]
|
||||||
|
end = split_points[i + 1] if i + 1 < len(split_points) else len(docs)
|
||||||
|
chunk = "".join(docs[start:end])
|
||||||
|
if chunk.strip():
|
||||||
|
result.append(chunk)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
long_text = """今天天气晴朗,适合去公园散步。
|
||||||
|
|
||||||
|
量子力学中的叠加态是描述粒子同时处于多个状态的数学工具。
|
||||||
|
|
||||||
|
Windows命令行中复制文件可以使用copy命令。
|
||||||
|
|
||||||
|
大熊猫主要以竹子为食,是中国的国宝。
|
||||||
|
|
||||||
|
欧拉公式被誉为“最美的数学公式”。"""
|
||||||
|
|
||||||
|
semantic_spitter = Semantic_splitter(window_size=2, threshold=0.85)
|
||||||
|
documents = semantic_spitter.create_documents(long_text)
|
||||||
|
print(f"总共分割为{len(documents)}个块")
|
||||||
|
|
||||||
|
for i, doc in enumerate(documents, 1):
|
||||||
|
print(f"---{i}个块")
|
||||||
|
print(doc)
|
||||||
+19
@@ -0,0 +1,19 @@
|
|||||||
|
# 使用豆包来向量化文本
|
||||||
|
|
||||||
|
import os
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
|
||||||
|
# 初始化客户端
|
||||||
|
client = Ark(
|
||||||
|
# 从环境变量中读取您的方舟API Key
|
||||||
|
api_key=os.environ.get("ARK_API_KEY", "79b39c58-56db-4d8a-a8f8-84b95fca08db"),
|
||||||
|
base_url="https://ark.cn-beijing.volces.com/api/v3",
|
||||||
|
)
|
||||||
|
response = client.embeddings.create(
|
||||||
|
model="doubao-embedding-text-240715",
|
||||||
|
input="Function Calling 是一种将大模型与外部工具和 API 相连的关键功能",
|
||||||
|
encoding_format="float",
|
||||||
|
)
|
||||||
|
# 打印结果
|
||||||
|
print(f"向量维度: {len(response.data[0].embedding)}")
|
||||||
|
print(f"前10维向量: {response.data[0].embedding[:10]}")
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
# 使用豆包来向量化文本
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
VOLC_EMBEDDINGS_API_URL = "https://ark.cn-beijing.volces.com/api/v3/embeddings"
|
||||||
|
VOLC_API_KEY = "79b39c58-56db-4d8a-a8f8-84b95fca08db"
|
||||||
|
|
||||||
|
|
||||||
|
def get_doubao_embedding(doc):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {VOLC_API_KEY}",
|
||||||
|
}
|
||||||
|
params = {"model": "doubao-embedding-text-240715", "input": doc}
|
||||||
|
response = requests.post(VOLC_EMBEDDINGS_API_URL, json=params, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
embedding = data["data"][0]["embedding"]
|
||||||
|
return embedding
|
||||||
|
else:
|
||||||
|
raise Exception(f"Embedding API error:{response.text}")
|
||||||
|
|
||||||
|
|
||||||
|
embedding = get_doubao_embedding("这是一段文档")
|
||||||
|
print(embedding)
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,31 @@
|
|||||||
|
# 创建临时客户端
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 创建一个临时的内存客户端(不会保存到硬盘)
|
||||||
|
client = chromadb.EphemeralClient()
|
||||||
|
|
||||||
|
# 创建一个集合
|
||||||
|
collection = client.create_collection(name="test")
|
||||||
|
|
||||||
|
# 添加一条数据
|
||||||
|
collection.add(
|
||||||
|
documents=["今天天气有风", "很冷", "注意保暖", "加油学习"],
|
||||||
|
ids=["test_1", "test_2", "test_3", "test_4"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询数据
|
||||||
|
results = collection.query(query_texts=["天气"], n_results=2)
|
||||||
|
|
||||||
|
print(f"打印数据结果{results}")
|
||||||
|
|
||||||
|
# {
|
||||||
|
# 'ids': [['test_1', 'test_2']],
|
||||||
|
# 'embeddings': None,
|
||||||
|
# 'documents': [['今天天气有风', '很冷']],
|
||||||
|
# 'uris': None,
|
||||||
|
# 'included': ['metadatas', 'documents', 'distances'],
|
||||||
|
# 'data': None,
|
||||||
|
# 'metadatas': [[None, None]],
|
||||||
|
# 'distances': [[0.2988046705722809, 0.9478188753128052]]
|
||||||
|
# }
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
# 持久化存储
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 持久化客户端
|
||||||
|
# path指定数据存储的路径
|
||||||
|
# 如果目录不存在,Chromadb会自动创建
|
||||||
|
persistent_client = chromadb.PersistentClient(path="./chromadb_store")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建一个集合(类似创建一个表)
|
||||||
|
collection = persistent_client.create_collection(
|
||||||
|
name="notes", metadata={"description": "笔记集合"} # 集合名称 # 集合元数据
|
||||||
|
)
|
||||||
|
|
||||||
|
# 列出所有集合,确认创建成功
|
||||||
|
# list_collections() 返回所有集合的列表
|
||||||
|
collections = persistent_client.list_collections()
|
||||||
|
print(collections)
|
||||||
|
|
||||||
|
for col in collections:
|
||||||
|
print(f"-{col.name}")
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
# 获取已经存在的集合
|
||||||
|
|
||||||
|
# 如果集合已经存在,可以使用get_collection() 或者 get_or_create_collection() 方法
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 创建持久化客户端
|
||||||
|
client = chromadb.PersistentClient(path="./chromadb_store")
|
||||||
|
|
||||||
|
# 方法1:获取已存在的集合
|
||||||
|
try:
|
||||||
|
existring_collection = client.get_collection(name="notes")
|
||||||
|
print("集合已经存在", existring_collection.name)
|
||||||
|
except Exception as e:
|
||||||
|
print("集合不存在", e)
|
||||||
|
|
||||||
|
|
||||||
|
# 方法2:获取或者创建集合(推荐使用)
|
||||||
|
|
||||||
|
collection = client.get_or_create_collection(
|
||||||
|
name="notes", metadata={"description": "笔记集合"}
|
||||||
|
)
|
||||||
|
print(collection.name)
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
# 写入数据
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 创建持久化客户端
|
||||||
|
|
||||||
|
client = chromadb.PersistentClient(path="./chromadb_store")
|
||||||
|
|
||||||
|
|
||||||
|
# 创建集合
|
||||||
|
collection = client.get_or_create_collection(name="knowledge_base")
|
||||||
|
|
||||||
|
# 准备说明文档
|
||||||
|
documents = [
|
||||||
|
"机器学习包含监督学习和无监督学习",
|
||||||
|
"Python 拥有丰富的数据科学生态",
|
||||||
|
"数据库可以持久化结构化或非结构化数据",
|
||||||
|
]
|
||||||
|
|
||||||
|
# 准备元组数据
|
||||||
|
metadatas = [
|
||||||
|
{"topic": "ml", "level": "intro"},
|
||||||
|
{"topic": "python", "level": "beginner"},
|
||||||
|
{"topic": "database", "level": "intro"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 准备唯一标识
|
||||||
|
# ids 是一个列表,每个元素对应一个文档的唯一ID
|
||||||
|
# 如果不提供,Chromedb会自动生成
|
||||||
|
ids = ["doc_1", "doc_2", "doc_3"]
|
||||||
|
|
||||||
|
# 将数据添加到集合中
|
||||||
|
# add() 方法会将文档转为向量
|
||||||
|
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
|
# 获取集合列表
|
||||||
|
collections = client.list_collections()
|
||||||
|
print(collections)
|
||||||
|
|
||||||
|
# 查看集合中的文档
|
||||||
|
doc_count = collection.count()
|
||||||
|
print(doc_count)
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
# 查询数据
|
||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 创建持久化客户端
|
||||||
|
client = chromadb.PersistentClient(path="./chromadb_store")
|
||||||
|
|
||||||
|
# 获取已经存在的集合
|
||||||
|
collection = client.get_collection(name="knowledge_base")
|
||||||
|
|
||||||
|
# query_texts 查询文本
|
||||||
|
# n_results 返回最相似的两条结果
|
||||||
|
results = collection.query(query_texts=["如何入门机器学习"], n_results=2)
|
||||||
|
|
||||||
|
# print(results)
|
||||||
|
|
||||||
|
# {
|
||||||
|
# "ids": [["doc_1", "doc_2"]],
|
||||||
|
# "embeddings": None,
|
||||||
|
# "documents": [
|
||||||
|
# ["机器学习包含监督学习和无监督学习", "Python 拥有丰富的数据科学生态"]
|
||||||
|
# ],
|
||||||
|
# "uris": None,
|
||||||
|
# "included": ["metadatas", "documents", "distances"],
|
||||||
|
# "data": None,
|
||||||
|
# "metadatas": [
|
||||||
|
# [{"level": "intro", "topic": "ml"}, {"topic": "python", "level": "beginner"}]
|
||||||
|
# ],
|
||||||
|
# "distances": [[0.24633410573005676, 0.8512163758277893]],
|
||||||
|
# }
|
||||||
|
|
||||||
|
for idx, (doc, metadata, distances, doc_id) in enumerate(
|
||||||
|
zip(
|
||||||
|
results["documents"][0],
|
||||||
|
results["metadatas"][0],
|
||||||
|
results["distances"][0],
|
||||||
|
results["ids"][0],
|
||||||
|
),
|
||||||
|
1,
|
||||||
|
):
|
||||||
|
print(f"结果{idx}")
|
||||||
|
print(f"文档ID{doc_id}")
|
||||||
|
print(f"匹配文档{doc}")
|
||||||
|
print(f"附加信息{metadata}")
|
||||||
|
print(f"相似度距离{distances}")
|
||||||
|
print("-" * 50)
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
# 完整流程
|
||||||
|
|
||||||
|
from chromadb import PersistentClient
|
||||||
|
|
||||||
|
# 创建持久化客户端
|
||||||
|
client = PersistentClient(path="./chromadb_store")
|
||||||
|
|
||||||
|
# 获取或者创建集合
|
||||||
|
collection = client.get_or_create_collection(name="example")
|
||||||
|
|
||||||
|
# 准备说明文档
|
||||||
|
documents = [
|
||||||
|
"机器学习包含监督学习和无监督学习",
|
||||||
|
"Python 拥有丰富的数据科学生态",
|
||||||
|
"数据库可以持久化结构化或非结构化数据",
|
||||||
|
]
|
||||||
|
# 创建元数据
|
||||||
|
metadatas = [
|
||||||
|
{"topic": "ml", "level": "intro"},
|
||||||
|
{"topic": "python", "level": "beginner"},
|
||||||
|
{"topic": "database", "level": "intro"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# ids
|
||||||
|
ids = ["doc1", "doc2", "dic3"]
|
||||||
|
|
||||||
|
# 写入数据
|
||||||
|
collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
|
abc = collection.get(ids=["doc2"])
|
||||||
|
print(abc)
|
||||||
|
|
||||||
|
# 查询
|
||||||
|
result = collection.query(query_texts=["如何入门机器学习"], n_results=2)
|
||||||
|
|
||||||
|
# print(result)
|
||||||
|
# {
|
||||||
|
# "ids": [["doc1", "doc2"]],
|
||||||
|
# "embeddings": None,
|
||||||
|
# "documents": [
|
||||||
|
# ["机器学习包含监督学习和无监督学习", "Python 拥有丰富的数据科学生态"]
|
||||||
|
# ],
|
||||||
|
# "uris": None,
|
||||||
|
# "included": ["metadatas", "documents", "distances"],
|
||||||
|
# "data": None,
|
||||||
|
# "metadatas": [
|
||||||
|
# [{"topic": "ml", "level": "intro"}, {"topic": "python", "level": "beginner"}]
|
||||||
|
# ],
|
||||||
|
# "distances": [[0.24633410573005676, 0.8512163758277893]],
|
||||||
|
# }
|
||||||
|
# for index, (id, doc, metadata, distance) in enumerate(
|
||||||
|
# zip(
|
||||||
|
# result["ids"][0],
|
||||||
|
# result["documents"][0],
|
||||||
|
# result["metadatas"][0],
|
||||||
|
# result["distances"][0],
|
||||||
|
# ),
|
||||||
|
# 1,
|
||||||
|
# ):
|
||||||
|
# print(f"匹配结果 {index}:")
|
||||||
|
# print(f" 文档:{doc}")
|
||||||
|
# print(f" 元数据:{metadata}")
|
||||||
|
# print(f" 距离:{distance:.4f}")
|
||||||
|
# print()
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
|
After Width: | Height: | Size: 117 KiB |
@@ -0,0 +1,5 @@
|
|||||||
|
姓名,年龄,城市,爱好,数学成绩,英语成绩,物理成绩
|
||||||
|
张三,18,北京,编程,95,88,90
|
||||||
|
李四,20,上海,篮球,87,92,85
|
||||||
|
王五,19,广州,音乐,78,85,80
|
||||||
|
赵六,21,深圳,绘画,90,89,88
|
||||||
|
Binary file not shown.
@@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Document</title>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
example html
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"name": "example",
|
||||||
|
"age": 18,
|
||||||
|
"city": "Beijing"
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
markdown example
|
||||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
|||||||
|
text example
|
||||||
Binary file not shown.
@@ -0,0 +1,12 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<person>
|
||||||
|
<name>张三</name>
|
||||||
|
<age>18</age>
|
||||||
|
<city>北京</city>
|
||||||
|
<hobby>编程</hobby>
|
||||||
|
<score>
|
||||||
|
<math>95</math>
|
||||||
|
<english>88</english>
|
||||||
|
<physics>90</physics>
|
||||||
|
</score>
|
||||||
|
</person>
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from rag!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# client = OpenAI(
|
||||||
|
# base_url="https://api.deepseek.com/v1",
|
||||||
|
# api_key="sk-01931083835f4a539e368b209559c52c",
|
||||||
|
# )
|
||||||
|
# response = client.chat.completions.create(
|
||||||
|
# model="deepseek-chat",
|
||||||
|
# messages=[
|
||||||
|
# {"role": "system", "content": "你是谁"},
|
||||||
|
# ],
|
||||||
|
# stream=True,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# for chunk in response:
|
||||||
|
# if chunk.choices[0].delta.content is not None:
|
||||||
|
# print(chunk.choices[0].delta.content, end="", flush=True)
|
||||||
|
|
||||||
|
# 调用自己写的
|
||||||
|
from openai_client import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url="https://api.deepseek.com",
|
||||||
|
api_key="sk-cc7b983a00f34cec9a12b19b64060f68",
|
||||||
|
)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="deepseek-chat",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "西游记作者是谁"},
|
||||||
|
],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
# print(response.choices[0].message.content)
|
||||||
|
for chunk in response:
|
||||||
|
if chunk.choices[0].delta.content is not None:
|
||||||
|
print(chunk.choices[0].delta.content, end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ChatCompletion(
|
||||||
|
# id='f8170d75-875c-4b46-bd3b-82a93d6be4c0',
|
||||||
|
# choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='你好!我是DeepSeek,一个由深度求索公司创造的AI助手。😊\n\n我是一个纯文本模型,虽然不支持多模态识别功能,但我有文件上传功能,可以帮你处理图像、txt、pdf、ppt、word、excel等文件,并从中读取文字信息进行分析处理。我完全免费使用,拥有128K的上下文长度,还支持联网搜索功能(需要你在Web/App中手动点开联网搜索按键)。\n\n你可以通过官方应用商店下载我的App来使用我。我很乐意帮助你解答问题、处理文档、进行对话交流等等!\n\n有什么我可以帮助你的吗?无论是学习、工作还是日常问题,我都很愿意为你提供帮助!✨', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None))],
|
||||||
|
# created=1765348625,
|
||||||
|
# model='deepseek-chat',
|
||||||
|
# object='chat.completion',
|
||||||
|
# service_tier=None,
|
||||||
|
# system_fingerprint='fp_eaab8d114b_prod0820_fp8_kvcache',
|
||||||
|
# usage=CompletionUsage(completion_tokens=143, prompt_tokens=4, total_tokens=147, completion_tokens_details=None, prompt_tokens_details=PromptTokensDetails(audio_tokens=None, cached_tokens=0), prompt_cache_hit_tokens=0, prompt_cache_miss_tokens=4))
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
import os
|
||||||
|
load_dotenv()
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||||
|
API_KEY = os.getenv("API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url=DATABASE_URL,
|
||||||
|
api_key=API_KEY,
|
||||||
|
)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="doubao-seed-1-6-lite-251015",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "你是谁"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
@@ -0,0 +1,175 @@
|
|||||||
|
# 封装统一调用openai的客户端
|
||||||
|
from typing import Optional, Iterator
|
||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.content = data.get("content")
|
||||||
|
self.role = data.get("role")
|
||||||
|
|
||||||
|
|
||||||
|
class Choice:
|
||||||
|
def __init__(self, choice):
|
||||||
|
self.index = choice.get("index")
|
||||||
|
self.finish_reason = choice.get("finish_reason")
|
||||||
|
self.message = Message(choice.get("message", {}))
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse:
|
||||||
|
def __init__(self, data) -> None:
|
||||||
|
self.id = data.get("id")
|
||||||
|
self.object = data.get("object")
|
||||||
|
self.created = data.get("created")
|
||||||
|
self.model = data.get("model")
|
||||||
|
choices_data = data.get("choices", [])
|
||||||
|
self.choices = [Choice(choice) for choice in choices_data]
|
||||||
|
usage_data = data.get("usage", {})
|
||||||
|
self.usage = {
|
||||||
|
"prompt_tokens": usage_data.get("prompt_tokens"),
|
||||||
|
"completion_tokens": usage_data.get("completion_tokens"),
|
||||||
|
"total_tokens": usage_data.get("total_tokens"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaMessage:
|
||||||
|
def __init__(self, data) -> None:
|
||||||
|
self.content = data.get("content")
|
||||||
|
self.role = data.get("role")
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaChoice:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.index = data.get("index")
|
||||||
|
self.finish_reason = data.get("finish_reason")
|
||||||
|
self.delta = DeltaMessage(data.get("delta", {}))
|
||||||
|
|
||||||
|
|
||||||
|
class StreamChunk:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.id = data.get("id")
|
||||||
|
self.object = data.get("object")
|
||||||
|
self.created = data.get("created")
|
||||||
|
self.model = data.get("model")
|
||||||
|
choices_data = data.get("choices", [])
|
||||||
|
self.choices = [DeltaChoice(choice) for choice in choices_data]
|
||||||
|
|
||||||
|
|
||||||
|
class Stream:
|
||||||
|
def __init__(self, response: requests.Response):
|
||||||
|
self.response = response
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.response.close()
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[StreamChunk]:
|
||||||
|
# 迭代器方法,逐个返回流式数据块
|
||||||
|
try:
|
||||||
|
# 逐行读取响应的内容(SSE格式)
|
||||||
|
for line in self.response.iter_lines(decode_unicode=True):
|
||||||
|
# print(line)
|
||||||
|
# data: {"id":"3eddf823-6ee6-4b14-a231-b0fd9dbc8087","object":"chat.completion.chunk","created":1765355109,"model":"deepseek-chat","system_fingerprint":"fp_eaab8d114b_prod0820_fp8_kvcache","choices":[{"index":0,"delta":{"content":"观点"},"logprobs":null,"finish_reason":null}]}
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
if line.startswith("data: "):
|
||||||
|
json_str = line[6:]
|
||||||
|
# 如果遇到DONE 结束,说明结束
|
||||||
|
if json_str.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(json_str)
|
||||||
|
yield StreamChunk(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
finally:
|
||||||
|
self.response.close()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletions:
|
||||||
|
def __init__(self, client):
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
messages,
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.7,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
url = f"{self._client.base_url}/chat/completions"
|
||||||
|
body = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
body["max_tokens"] = max_tokens
|
||||||
|
if temperature is not None:
|
||||||
|
body["temperature"] = temperature
|
||||||
|
if stream:
|
||||||
|
body["stream"] = True
|
||||||
|
|
||||||
|
# 将其他参数添加到body中
|
||||||
|
body.update(kwargs)
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Authorization": f"Bearer {self._client.api_key}",
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
timeout=self._client.timeout,
|
||||||
|
stream=True, # 告诉openai的服务器我要使用流式输出
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return Stream(response)
|
||||||
|
else:
|
||||||
|
response = requests.post(
|
||||||
|
url, headers=headers, json=body, timeout=self._client.timeout
|
||||||
|
)
|
||||||
|
# 如果响应状态不是2xx则直接报错
|
||||||
|
# response.raise_for_status()是 requests库中一个非常重要的方法,用于自动检查 HTTP 响应状态码,并在状态码表示错误时抛出异常。
|
||||||
|
# 如果状态码是 2xx(成功):什么都不做,继续执行
|
||||||
|
# 如果状态码是 4xx 或 5xx(客户端或服务器错误):抛出异常
|
||||||
|
response.raise_for_status()
|
||||||
|
return ChatCompletionResponse(response.json())
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResource:
|
||||||
|
def __init__(self, client):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completions(self):
|
||||||
|
return ChatCompletions(self.client)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAI:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str = "https://api.deepseek.com/v1",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
timeout: float = 60.0,
|
||||||
|
):
|
||||||
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(
|
||||||
|
f"API秘钥未设置,请设置api_key参数或设置环境变量OPENAI_API_KEY"
|
||||||
|
)
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
# 可以使用属性.的方式使用方法
|
||||||
|
@property
|
||||||
|
def chat(self):
|
||||||
|
return ChatResource(self)
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
|
||||||
|
from typing import Optional,Iterator
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
class Message:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.role = data.get('role'),
|
||||||
|
self.content = data.get('content')
|
||||||
|
class Choice:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.index = data.get('index')
|
||||||
|
self.message = Message(data.get('message',{}))
|
||||||
|
self.finish_reason = data.get('finish_reason')
|
||||||
|
class ChatCompletionResponse:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.id = data.get('id')
|
||||||
|
self.object = data.get('object')
|
||||||
|
self.created = data.get('created')
|
||||||
|
self.model = data.get('model')
|
||||||
|
choices_data = data.get('choices',[])
|
||||||
|
self.choices = [
|
||||||
|
Choice(choice_data) for choice_data in choices_data
|
||||||
|
]
|
||||||
|
usage_data = data.get('usage',{})
|
||||||
|
self.usage = {
|
||||||
|
"prompt_tokens":usage_data.get("prompt_tokens"),
|
||||||
|
"completion_tokens":usage_data.get("completion_tokens"),
|
||||||
|
"total_tokens":usage_data.get("total_tokens"),
|
||||||
|
}
|
||||||
|
class DeltaMessage:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.content = data.get('content')
|
||||||
|
self.role = data.get('role')
|
||||||
|
class DeltaChoice:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.index = data.get('index')
|
||||||
|
self.delta = DeltaMessage(data.get('delta',{}) )
|
||||||
|
self.finish_reason = data.get('finish_reason')
|
||||||
|
#流式响应数据块,表示流式响应中的一个数据块
|
||||||
|
class StreamChunk:
|
||||||
|
def __init__(self,data):
|
||||||
|
self.id = data.get('id')
|
||||||
|
self.object = data.get('object')
|
||||||
|
self.created = data.get('created')
|
||||||
|
self.model = data.get('model')
|
||||||
|
choices_data = data.get('choices',[])
|
||||||
|
self.choices = [DeltaChoice(choice_data) for choice_data in choices_data]
|
||||||
|
|
||||||
|
class Stream:
|
||||||
|
def __init__(self,response:requests.Response):
|
||||||
|
self.response=response
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
def __exit__(self,exc_type,exc_val,exc_tb):
|
||||||
|
self.response.close()
|
||||||
|
def __iter__(self)->Iterator[StreamChunk]:
|
||||||
|
#迭代器方法,逐个返回流式数据块
|
||||||
|
try:
|
||||||
|
# 逐行读取响应的内容(SSE格式)
|
||||||
|
for line in self.response.iter_lines(decode_unicode=True):
|
||||||
|
#print(line)
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
if line.startswith('data: '):
|
||||||
|
json_str = line[6:]
|
||||||
|
# 如果遇到[DONE]说明流式输出结束
|
||||||
|
if json_str.strip()=="[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
data = json.loads(json_str)
|
||||||
|
yield StreamChunk(data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.response.close()
|
||||||
|
|
||||||
|
class ChatCompletions:
|
||||||
|
def __init__(self,client):
|
||||||
|
self._client = client
|
||||||
|
def create(self,model,messages,max_tokens=1024,temperature=0.7,stream:bool=False,**kwargs):
|
||||||
|
url = f"{self._client.base_url}/chat/completions"
|
||||||
|
body = {
|
||||||
|
"model":model,
|
||||||
|
"messages":messages
|
||||||
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
body["max_tokens"]=max_tokens
|
||||||
|
if temperature is not None:
|
||||||
|
body["temperature"]=temperature
|
||||||
|
if stream:
|
||||||
|
body["stream"]=True
|
||||||
|
#添加额外的参数到请求体中
|
||||||
|
body.update(kwargs)
|
||||||
|
headers = {
|
||||||
|
"Authorization":f"Bearer {self._client.api_key}",
|
||||||
|
"Content-Type":"application/json"
|
||||||
|
}
|
||||||
|
if stream:
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
timeout=self._client.timeout,
|
||||||
|
stream=True#告诉openai的服务器我要使用流式输出
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return Stream(response)
|
||||||
|
else:
|
||||||
|
response = requests.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=body,
|
||||||
|
timeout=self._client.timeout
|
||||||
|
)
|
||||||
|
# 如果响应的状态不是2XX的话,主抛异常
|
||||||
|
response.raise_for_status()
|
||||||
|
return ChatCompletionResponse(response.json())
|
||||||
|
|
||||||
|
class ChatResource:
|
||||||
|
def __init__(self,client):
|
||||||
|
self._client = client
|
||||||
|
@property
|
||||||
|
def completions(self)->ChatCompletions:
|
||||||
|
return ChatCompletions(self._client)
|
||||||
|
class OpenAI:
|
||||||
|
def __init__(self,api_key:Optional[str]=None,base_url:str="https://api.openai.com/v1",timeout:float=60.0):
|
||||||
|
self.api_key=api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError(f"API密钥未设置,请设置api_key参数或者环境变量OPENAI_API_KEY")
|
||||||
|
self.base_url = base_url.rstrip('/')
|
||||||
|
self.timeout = timeout
|
||||||
|
@property
|
||||||
|
def chat(self)->ChatResource:
|
||||||
|
return ChatResource(self)
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
from playwright.sync_api import sync_playwright
|
||||||
|
|
||||||
|
|
||||||
|
def test_website():
|
||||||
|
with sync_playwright() as p:
|
||||||
|
# 启动浏览器
|
||||||
|
browser = p.chromium.launch(headless=True) # headless=False 显示浏览器窗口
|
||||||
|
page = browser.new_page()
|
||||||
|
|
||||||
|
# 访问网页
|
||||||
|
page.goto("https://www.baidu.com")
|
||||||
|
|
||||||
|
# 截图
|
||||||
|
page.screenshot(path="example.png")
|
||||||
|
|
||||||
|
# 获取页面标题
|
||||||
|
title = page.title()
|
||||||
|
print(f"页面标题: {title}")
|
||||||
|
|
||||||
|
# 关闭浏览器
|
||||||
|
browser.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_website()
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
[project]
|
||||||
|
name = "rag"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"beautifulsoup4>=4.14.2",
|
||||||
|
"chromadb>=1.3.5",
|
||||||
|
"langchain-text-splitters>=1.0.0",
|
||||||
|
"load-dotenv>=0.1.0",
|
||||||
|
"lxml>=6.0.2",
|
||||||
|
"openai>=2.9.0",
|
||||||
|
"openpyxl>=3.1.5",
|
||||||
|
"playwright>=1.56.0",
|
||||||
|
"pymupdf>=1.26.6",
|
||||||
|
"python-docx>=1.2.0",
|
||||||
|
"python-dotenv>=1.2.1",
|
||||||
|
"python-pptx>=1.0.2",
|
||||||
|
"sentence-transformers>=5.1.2",
|
||||||
|
"volcengine>=1.0.207",
|
||||||
|
"volcengine-python-sdk[ark]>=4.0.35",
|
||||||
|
]
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
class Message {
|
||||||
|
constructor(data) {
|
||||||
|
this.data = data;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const msg = new Message({ o: "000" });
|
||||||
|
|
||||||
|
console.log(msg.data.o);
|
||||||
@@ -0,0 +1,117 @@
|
|||||||
|
# for i in range(2, 10):
|
||||||
|
# print(i)
|
||||||
|
|
||||||
|
# DEFAULT_COLLECTION_NAME = "rag"
|
||||||
|
|
||||||
|
|
||||||
|
# def abc(collection_name=DEFAULT_COLLECTION_NAME):
|
||||||
|
# print(collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
# abc()
|
||||||
|
|
||||||
|
|
||||||
|
# aaa = " "
|
||||||
|
# print(not aaa.strip())
|
||||||
|
|
||||||
|
# o = {"a": 1, "b": 2}
|
||||||
|
# for value in o.values():
|
||||||
|
# # print("k", key)
|
||||||
|
# print("value", value)
|
||||||
|
|
||||||
|
# from typing import Optional, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
# class ChatCompletions:
|
||||||
|
# def __init__(self, client):
|
||||||
|
# self._client = client
|
||||||
|
|
||||||
|
# def create(self, model):
|
||||||
|
# print(f"{model}")
|
||||||
|
|
||||||
|
|
||||||
|
# class ChatResource:
|
||||||
|
# def __init__(self, client):
|
||||||
|
# self._client = client
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def completions(self) -> ChatCompletions:
|
||||||
|
# return ChatCompletions(self._client)
|
||||||
|
|
||||||
|
|
||||||
|
# class OpenAI:
|
||||||
|
# def __init__(
|
||||||
|
# self,
|
||||||
|
# api_key: Optional[str] = None,
|
||||||
|
# base_url: str = "https://api.openai.com/v1",
|
||||||
|
# timeout: float = 60.0,
|
||||||
|
# ):
|
||||||
|
# self.api_key = "111"
|
||||||
|
# if not self.api_key:
|
||||||
|
# raise ValueError(
|
||||||
|
# f"API密钥未设置,请设置api_key参数或者环境变量OPENAI_API_KEY"
|
||||||
|
# )
|
||||||
|
# self.base_url = base_url.rstrip("/")
|
||||||
|
# self.timeout = timeout
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def chat(self) -> ChatResource:
|
||||||
|
# return ChatResource(self)
|
||||||
|
|
||||||
|
|
||||||
|
# client = OpenAI()
|
||||||
|
# client.chat.completions.create(model="openai")
|
||||||
|
|
||||||
|
|
||||||
|
# class Obj:
|
||||||
|
# def __init__(self, data):
|
||||||
|
# self.o = data.get("o")
|
||||||
|
|
||||||
|
|
||||||
|
# class Message:
|
||||||
|
# def __init__(self, data):
|
||||||
|
# self.obj = Obj(data)
|
||||||
|
|
||||||
|
|
||||||
|
# msg = Message({"o": "112"})
|
||||||
|
# print(msg.obj.o)
|
||||||
|
|
||||||
|
class Choice:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.index = data.get("index")
|
||||||
|
self.message = data.get("message")
|
||||||
|
self.finish_reason = data.get("finish_reason")
|
||||||
|
|
||||||
|
class Completion:
|
||||||
|
def __init__(self, data):
|
||||||
|
self.id = data.get("id")
|
||||||
|
self.object = data.get("object")
|
||||||
|
self.created = data.get("created")
|
||||||
|
self.model = data.get("model")
|
||||||
|
self.choices = [Choice(choice) for choice in data.get("choices", [])]
|
||||||
|
|
||||||
|
completion = Completion({"id": "123", "object": "completion", "created": 1718000000, "model": "gpt-3.5-turbo",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": "Hello, how can I help you today?"}, "finish_reason": "stop"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
class Student:
|
||||||
|
def __init__(self, name, age, score):
|
||||||
|
self.name = name # 实例属性
|
||||||
|
self.age = age
|
||||||
|
self.score = score
|
||||||
|
|
||||||
|
def introduce(self):
|
||||||
|
print(f"我是{self.name},今年{self.age}岁,分数{self.score}")
|
||||||
|
|
||||||
|
def is_pass(self):
|
||||||
|
return self.score >= 60
|
||||||
|
|
||||||
|
# 创建一个学生对象
|
||||||
|
stu1 = Student("Alice", 20, 85)
|
||||||
|
print(stu1.name)
|
||||||
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+410
@@ -0,0 +1,410 @@
|
|||||||
|
# 导入PyMuPDF库(fitz),用于处理PDF文件
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
# 导入Optional类型提示
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# 导入日志logging功能
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 获取当前模块日志记录器
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 定义用于提取PDF所有文本内容的函数
|
||||||
|
def extract_pdf_text(pdf_path: str) -> str:
|
||||||
|
"""
|
||||||
|
提取PDF文件中的所有文本内容
|
||||||
|
|
||||||
|
参数:
|
||||||
|
pdf_path (str): PDF文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 合并后的所有页文本
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
Exception: PDF文件读取失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 打开PDF文件
|
||||||
|
pdf = fitz.open(pdf_path)
|
||||||
|
try:
|
||||||
|
# 新建一个空列表,用来存储每页文本
|
||||||
|
text_list = []
|
||||||
|
# 遍历每一页
|
||||||
|
for page in pdf:
|
||||||
|
# 获取当前页文本,并加入列表
|
||||||
|
text_list.append(page.get_text("text")) # type: ignore
|
||||||
|
# 将每页文本用换行拼接成一个大字符串
|
||||||
|
all_text = "\n".join(text_list)
|
||||||
|
# 返回拼接后的文本
|
||||||
|
return all_text
|
||||||
|
finally:
|
||||||
|
# 确保关闭PDF文件
|
||||||
|
pdf.close()
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 如果文件未找到,记录错误日志
|
||||||
|
logger.error(f"PDF文件不存在: {pdf_path}")
|
||||||
|
# 向上抛出异常
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其他异常情况,记录错误信息
|
||||||
|
logger.error(f"提取PDF文本失败: {pdf_path}, 错误: {str(e)}")
|
||||||
|
# 抛出异常
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入python-docx的Document类
|
||||||
|
from docx import Document
|
||||||
|
|
||||||
|
|
||||||
|
# 定义提取Word文档所有段落文本的函数
|
||||||
|
def extract_text_from_word(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
从Word文档中提取所有段落的文本,并以字符串返回。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): Word文档的路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 文本内容字符串
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
Exception: Word文件读取失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 加载Word文档
|
||||||
|
doc = Document(file_path)
|
||||||
|
# 取所有段落的文本,并用换行符拼接
|
||||||
|
text = "\n".join([para.text for para in doc.paragraphs])
|
||||||
|
# 返回拼接好的文本
|
||||||
|
return text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件未找到时记录日志
|
||||||
|
logger.error(f"Word文件不存在: {file_path}")
|
||||||
|
# 抛出异常
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常记录错误信息
|
||||||
|
logger.error(f"提取Word文本失败: {file_path}, 错误: {str(e)}")
|
||||||
|
# 抛出异常
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入openpyxl库,用于操作Excel文件
|
||||||
|
import openpyxl
|
||||||
|
|
||||||
|
|
||||||
|
# 定义函数提取Excel文件中的所有文本
|
||||||
|
def extract_text_from_excel(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
从Excel文件中提取所有单元格内容为文本,并以字符串返回。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): Excel文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 文本内容字符串
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
Exception: Excel文件读取失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 加载Excel工作簿
|
||||||
|
wb = openpyxl.load_workbook(file_path, data_only=True)
|
||||||
|
try:
|
||||||
|
# 取得活动工作表
|
||||||
|
ws = wb.active
|
||||||
|
# 新建空列表保存每一行字符串
|
||||||
|
rows = []
|
||||||
|
# 遍历所有行,只取单元格的值
|
||||||
|
for row in ws.iter_rows(values_only=True):
|
||||||
|
# 将每行单元格内容用Tab连接,空值转换为空字符串
|
||||||
|
rows.append(
|
||||||
|
"\t".join([str(cell) if cell is not None else "" for cell in row])
|
||||||
|
)
|
||||||
|
# 用换行符拼接所有行
|
||||||
|
all_text = "\n".join(rows)
|
||||||
|
# 返回最终文本
|
||||||
|
return all_text
|
||||||
|
finally:
|
||||||
|
# 关闭Excel工作簿
|
||||||
|
wb.close()
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件未找到时日志记录
|
||||||
|
logger.error(f"Excel文件不存在: {file_path}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常日志并抛出
|
||||||
|
logger.error(f"提取Excel文本失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入python-pptx库的Presentation类
|
||||||
|
from pptx import Presentation
|
||||||
|
|
||||||
|
|
||||||
|
# 定义函数提取PPT文件所有文本内容
|
||||||
|
def extract_ppt_text(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
提取PPT文件中的所有文本内容,并以字符串返回。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): PPT文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 所有文本内容(以换行符分隔)
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
Exception: PPT文件读取失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 加载PPT文件
|
||||||
|
ppt = Presentation(file_path)
|
||||||
|
# 新建列表存储所有文本内容
|
||||||
|
text_list = []
|
||||||
|
# 遍历PPT中的每张幻灯片
|
||||||
|
for slide in ppt.slides:
|
||||||
|
# 遍历当前幻灯片的每个形状
|
||||||
|
for shape in slide.shapes:
|
||||||
|
# 判断是否含有文本,且文本不为空
|
||||||
|
if hasattr(shape, "text") and shape.text.strip():
|
||||||
|
# 有文本时加入结果列表
|
||||||
|
text_list.append(shape.text)
|
||||||
|
# 用换行符拼接所有文本
|
||||||
|
all_text = "\n".join(text_list)
|
||||||
|
# 返回所有文本内容
|
||||||
|
return all_text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件未找到时日志打印
|
||||||
|
logger.error(f"PPT文件不存在: {file_path}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 处理其它异常
|
||||||
|
logger.error(f"提取PPT文本失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入BeautifulSoup用于解析HTML
|
||||||
|
from bs4 import BeautifulSoup # BeautifulSoup用于解析HTML
|
||||||
|
|
||||||
|
|
||||||
|
# 定义函数,从HTML文件提取所有文本内容
|
||||||
|
def extract_text_from_html(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
从指定HTML文件中提取所有文本内容
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): HTML文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 提取的文本内容
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
Exception: HTML文件读取失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 以utf-8编码方式打开HTML文件
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 读取HTML文件所有内容
|
||||||
|
html = f.read()
|
||||||
|
# 创建BeautifulSoup对象
|
||||||
|
soup = BeautifulSoup(html, "html.parser")
|
||||||
|
# 用换行分隔符获取全部文本
|
||||||
|
text = soup.get_text(separator="\n", strip=True)
|
||||||
|
# 返回文本
|
||||||
|
return text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件不存在,记录日志
|
||||||
|
logger.error(f"HTML文件不存在: {file_path}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常记录并抛出
|
||||||
|
logger.error(f"提取HTML文本失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入内置json库
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
# 定义提取JSON文件文本内容的函数
|
||||||
|
def extract_text_from_json(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
从JSON文件中提取文本内容并格式化为字符串
|
||||||
|
|
||||||
|
参数:
|
||||||
|
filename (str): JSON文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 格式化后的JSON文本内容
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
json.JSONDecodeError: JSON解析失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 以utf-8编码打开JSON文件
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
# 加载JSON内容到Python对象
|
||||||
|
data = json.load(f)
|
||||||
|
# 格式化JSON为缩进文本,显示中文
|
||||||
|
text = json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
# 返回字符串格式JSON内容
|
||||||
|
return text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件不存在时记录日志
|
||||||
|
logger.error(f"JSON文件不存在: {filename}")
|
||||||
|
raise
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
# JSON解析异常日志
|
||||||
|
logger.error(f"JSON解析失败: {filename}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入lxml库的etree模块用于XML处理
|
||||||
|
from lxml import etree
|
||||||
|
|
||||||
|
|
||||||
|
# 定义函数,从XML文件提取所有文本内容
|
||||||
|
def extract_xml_text(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
读取XML文件并提取所有文本内容
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): XML文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 提取的所有文本内容
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
etree.XMLSyntaxError: XML解析失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 用utf-8编码打开XML文件
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 读取XML字符串内容
|
||||||
|
xml = f.read()
|
||||||
|
# 解析为XML树结构对象
|
||||||
|
root = etree.fromstring(xml.encode("utf-8"))
|
||||||
|
# 遍历所有文本节点并用空格拼接
|
||||||
|
text = " ".join(root.itertext())
|
||||||
|
# 返回拼接后的文本
|
||||||
|
return text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件不存在日志
|
||||||
|
logger.error(f"XML文件不存在: {file_path}")
|
||||||
|
raise
|
||||||
|
except etree.XMLSyntaxError as e:
|
||||||
|
# XML语法异常日志
|
||||||
|
logger.error(f"XML解析失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常日志
|
||||||
|
logger.error(f"提取XML文本失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 导入csv模块
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
# 定义读取CSV内容并串成字符串的函数
|
||||||
|
def read_csv_to_text(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
读取CSV文件内容,并将每行用逗号连接,所有行用换行符拼接成一个字符串返回。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
filename (str): CSV文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 拼接后的字符串
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 以utf-8编码方式打开CSV文件
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
# 创建csv.reader对象逐行读取
|
||||||
|
reader = csv.reader(f)
|
||||||
|
# 每行用逗号拼接并放到列表
|
||||||
|
rows = [", ".join(row) for row in reader]
|
||||||
|
# 用换行拼接所有行
|
||||||
|
all_text = "\n".join(rows)
|
||||||
|
# 返回结果
|
||||||
|
return all_text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件不存在日志
|
||||||
|
logger.error(f"CSV文件不存在: {filename}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常日志
|
||||||
|
logger.error(f"读取CSV文件失败: {filename}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 定义读取文本文件内容的函数
|
||||||
|
def read_text_file(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
读取指定文本文件内容并返回
|
||||||
|
|
||||||
|
参数:
|
||||||
|
filename (str): 文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 文件内容字符串
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 以utf-8只读方式打开文本文件
|
||||||
|
with open(filename, "r", encoding="utf-8") as f:
|
||||||
|
# 读取文件的所有内容
|
||||||
|
text = f.read()
|
||||||
|
# 返回字符串
|
||||||
|
return text
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件未找到记录日志
|
||||||
|
logger.error(f"文本文件不存在: {filename}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常情况日志记录
|
||||||
|
logger.error(f"读取文本文件失败: {filename}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# 定义读取Markdown文件内容的函数
|
||||||
|
def read_markdown_file(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
读取Markdown文件内容并返回
|
||||||
|
|
||||||
|
参数:
|
||||||
|
file_path (str): Markdown文件路径
|
||||||
|
|
||||||
|
返回:
|
||||||
|
str: 文件内容字符串
|
||||||
|
|
||||||
|
异常:
|
||||||
|
FileNotFoundError: 文件不存在
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 以utf-8编码只读打开Markdown文件
|
||||||
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
# 读取并返回全部内容
|
||||||
|
return f.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
# 文件不存在日志
|
||||||
|
logger.error(f"Markdown文件不存在: {file_path}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# 其它异常日志
|
||||||
|
logger.error(f"读取Markdown文件失败: {file_path}, 错误: {str(e)}")
|
||||||
|
raise
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import extract
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extractTextAuto(file_path: str) -> str:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
raise FileNotFoundError(f"文件不存在:{file_path}")
|
||||||
|
# 获取文件拓展名
|
||||||
|
ext = os.path.splitext(file_path)[-1].lower()
|
||||||
|
try:
|
||||||
|
# 如果是pdf文件
|
||||||
|
if ext == ".pdf":
|
||||||
|
logger.info(f"检测到PDF文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_pdf_text(file_path)
|
||||||
|
# 如果是Word文档
|
||||||
|
elif ext in [".docx", ".doc"]:
|
||||||
|
logger.info(f"检测到Word文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_text_from_word(file_path)
|
||||||
|
# 如果是Excel文件
|
||||||
|
elif ext in [".xlsx", ".xls"]:
|
||||||
|
logger.info(f"检测到Excel文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_text_from_excel(file_path)
|
||||||
|
# 如果是PPT文件
|
||||||
|
elif ext in [".pptx", ".ppt"]:
|
||||||
|
logger.info(f"检测到PPT文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_ppt_text(file_path)
|
||||||
|
# 如果是HTML文件
|
||||||
|
elif ext in [".html", ".htm"]:
|
||||||
|
logger.info(f"检测到HTML文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_text_from_html(file_path)
|
||||||
|
# 如果是XML文件
|
||||||
|
elif ext == ".xml":
|
||||||
|
logger.info(f"检测到XML文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_xml_text(file_path)
|
||||||
|
# 如果是CSV文件
|
||||||
|
elif ext == ".csv":
|
||||||
|
logger.info(f"检测到CSV文件,开始提取文本: {file_path}")
|
||||||
|
return extract.read_csv_to_text(file_path)
|
||||||
|
# 如果是JSON文件
|
||||||
|
elif ext == ".json":
|
||||||
|
logger.info(f"检测到JSON文件,开始提取文本: {file_path}")
|
||||||
|
return extract.extract_text_from_json(file_path)
|
||||||
|
# 如果是纯文本、Markdown、JSONL文件
|
||||||
|
elif ext in [".md", ".txt", ".jsonl"]:
|
||||||
|
logger.info(f"检测到文本/Markdown/JSONL文件,开始读取: {file_path}")
|
||||||
|
return extract.read_text_file(file_path)
|
||||||
|
# 其余不支持的文件类型
|
||||||
|
else:
|
||||||
|
logger.error(f"不支持的文件类型: {ext}")
|
||||||
|
raise ValueError(f"不支持的文件类型: {ext}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise
|
||||||
+55
@@ -0,0 +1,55 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
# Install SDK: pip install 'volcengine-python-sdk[ark]'
|
||||||
|
# from volcenginesdkarkruntime import Ark
|
||||||
|
|
||||||
|
# client = Ark(
|
||||||
|
# # The base URL for model invocation
|
||||||
|
# base_url="https://ark.cn-beijing.volces.com/api/v3/chat/completions",
|
||||||
|
# api_key=os.getenv("ARK_API_KEY", "79b39c58-56db-4d8a-a8f8-84b95fca08db"),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# completion = client.chat.completions.create(
|
||||||
|
# # Replace with Model ID
|
||||||
|
# model="doubao-seed-1-6-lite-251015",
|
||||||
|
# messages=[
|
||||||
|
# {
|
||||||
|
# "role": "system",
|
||||||
|
# "content": "请将下面内容进行结构化处理:火山方舟是火山引擎推出的大模型服务平台,提供模型训练、推理、评测、精调等全方位功能与服务,并重点支撑大模型生态。 火山方舟通过稳定可靠的安全互信方案,保障模型提供方的模型安全与模型使用者的信息安全,加速大模型能力渗透到千行百业,助力模型提供方和使用者实现商业新增长。",
|
||||||
|
# },
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
|
||||||
|
# print(completion.choices[0].message.content)
|
||||||
|
|
||||||
|
# 使用豆包来向量化文本
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
VOLC_EMBEDDINGS_API_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||||
|
VOLC_API_KEY = "79b39c58-56db-4d8a-a8f8-84b95fca08db"
|
||||||
|
|
||||||
|
|
||||||
|
def get_doubao_llm(prompt):
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {VOLC_API_KEY}",
|
||||||
|
}
|
||||||
|
params = {
|
||||||
|
"model": "doubao-seed-1-6-lite-251015",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": f"{prompt}"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
response = requests.post(VOLC_EMBEDDINGS_API_URL, json=params, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
print(data)
|
||||||
|
message = data["choices"][0]["message"]["content"]
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
raise Exception(f"Embedding API error:{response.text}")
|
||||||
|
|
||||||
|
|
||||||
|
answer = get_doubao_llm("红楼梦的作者是谁")
|
||||||
|
print(answer)
|
||||||
+118
@@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional, List
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from llm import get_doubao_llm
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 默认集合的名称
|
||||||
|
DEFAULT_COLLECTION_NAME = "rag_system"
|
||||||
|
# 返回几条数据
|
||||||
|
DEFAULT_N_RESULTS = 2
|
||||||
|
# 默认向量化模型的名称
|
||||||
|
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
|
||||||
|
# 定义向量模型的全局变量
|
||||||
|
_mode: Optional[SentenceTransformer] = None
|
||||||
|
# 定义chromadb客户端
|
||||||
|
_client: Optional[chromadb.PersistentClient] = None
|
||||||
|
_collection: Optional[chromadb.Collection] = None
|
||||||
|
|
||||||
|
# 默认数据库存放路径
|
||||||
|
DEFAULT_DB_PATH = "./chroma_db"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model():
|
||||||
|
global _mode
|
||||||
|
if _mode is None:
|
||||||
|
_mode = SentenceTransformer(DEFAULT_MODEL_NAME)
|
||||||
|
return _mode
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client():
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
_client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_embedding(query: str) -> List[float]:
|
||||||
|
model = _get_model()
|
||||||
|
embedding = model.encode([query])[0].tolist()
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
def _get_collection(collection_name: str = DEFAULT_COLLECTION_NAME):
|
||||||
|
global _collection
|
||||||
|
if _collection is None:
|
||||||
|
client = _get_client()
|
||||||
|
_collection = client.get_or_create_collection(collection_name)
|
||||||
|
return _collection
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve_relate_chunks(
|
||||||
|
query_embedding: List[float],
|
||||||
|
n_results: int = DEFAULT_N_RESULTS,
|
||||||
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
collection = _get_collection(collection_name)
|
||||||
|
# print(n_results)
|
||||||
|
# 去指定集合查找相似度检索,找到数据
|
||||||
|
results = collection.query(
|
||||||
|
query_embeddings=[query_embedding], n_results=n_results
|
||||||
|
)
|
||||||
|
related_chunks = results.get("documents")
|
||||||
|
if not related_chunks or not related_chunks[0]:
|
||||||
|
raise ValueError("未找到相关内容")
|
||||||
|
return related_chunks[0]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向量检索失败:{str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def query_rag(
|
||||||
|
query: str,
|
||||||
|
n_results: int = DEFAULT_N_RESULTS,
|
||||||
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
查询函数:
|
||||||
|
query:用户查询的问题
|
||||||
|
n_results:检索数量
|
||||||
|
collection_name: 集合名字
|
||||||
|
"""
|
||||||
|
# 1. 将查询问题转为向量
|
||||||
|
query_embedding = get_query_embedding(query)
|
||||||
|
# print(query_embedding)
|
||||||
|
# 基于查询向量做检索
|
||||||
|
related_chunks = retrieve_relate_chunks(
|
||||||
|
query_embedding, n_results, collection_name=collection_name
|
||||||
|
)
|
||||||
|
# print("related_chunks", related_chunks)
|
||||||
|
content = "\n".join(related_chunks)
|
||||||
|
prompt = f"""
|
||||||
|
已知信息:{content}
|
||||||
|
请根据上述内容回答用户问题:{query}
|
||||||
|
"""
|
||||||
|
print(prompt)
|
||||||
|
answer = get_doubao_llm(prompt)
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
query = "西游记是谁写的"
|
||||||
|
|
||||||
|
try:
|
||||||
|
answer = query_rag(query, n_results=1)
|
||||||
|
print(f"答案:", answer)
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"错误{e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误{e}")
|
||||||
+65
@@ -0,0 +1,65 @@
|
|||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from vectorstore import save_text_to_db
|
||||||
|
from extract_text_auto import extractTextAuto
|
||||||
|
|
||||||
|
# 日志打印格式
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 默认保存chromadb集合名称
|
||||||
|
DEFAULT_COLLECTION_NAME = "rag_system"
|
||||||
|
# 默认分块大小
|
||||||
|
DEFAULT_CHUNK_SIZE = 200
|
||||||
|
# 默认分块重叠度
|
||||||
|
DEFAULT_CHUNK_OVERLAP = 30
|
||||||
|
|
||||||
|
|
||||||
|
def doc_to_vectorstore(
|
||||||
|
file_path: str,
|
||||||
|
collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||||
|
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||||
|
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
提供文档内容,并分块保存到向量数据库中
|
||||||
|
参数:
|
||||||
|
file_path:文件路径
|
||||||
|
collection_name:集合名称
|
||||||
|
chunk_size:分块大小
|
||||||
|
chunk_overlap:分块重叠
|
||||||
|
"""
|
||||||
|
# 1. 先加载文件
|
||||||
|
text = extractTextAuto(file_path)
|
||||||
|
print(text)
|
||||||
|
if not text.strip():
|
||||||
|
logger.warning(f"文件内容为空:{file_path}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 2.进行分块
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
chunks = text_splitter.split_text(text)
|
||||||
|
logger.info(f"文件分块完成,共分为{len(chunks)}块")
|
||||||
|
|
||||||
|
# 3.将分好的块,保存到向量化,且保存到向量数据库中
|
||||||
|
success_count = 0
|
||||||
|
for idx, chunk in enumerate(chunks):
|
||||||
|
try:
|
||||||
|
save_text_to_db(chunk, collection_name=collection_name)
|
||||||
|
success_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存第{idx+1}块失败:{str(e)}")
|
||||||
|
logger.info(
|
||||||
|
f"文件{file_path}已经完成向量化并入库,成功保存{success_count}/{len(chunks)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
doc_to_vectorstore("西游记.txt")
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
import chromadb
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 默认集合名称
|
||||||
|
DEFAULT_COLLECTION_NAME = "rag"
|
||||||
|
# 默认向量化模型名称
|
||||||
|
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
|
||||||
|
# 默认数据库存放路径
|
||||||
|
DEFAULT_DB_PATH = "./chroma_db"
|
||||||
|
|
||||||
|
# 定义全局mode
|
||||||
|
_model: Optional[SentenceTransformer] = None
|
||||||
|
# 定义全局客户端
|
||||||
|
_client: Optional[chromadb.PersistentClient] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mode():
|
||||||
|
global _model
|
||||||
|
if _model is None:
|
||||||
|
_model = SentenceTransformer(DEFAULT_MODEL_NAME)
|
||||||
|
return _model
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client():
|
||||||
|
global _client
|
||||||
|
if _client is None:
|
||||||
|
_client = chromadb.PersistentClient(path=DEFAULT_DB_PATH)
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def save_text_to_db(text: str, collection_name=DEFAULT_COLLECTION_NAME):
|
||||||
|
try:
|
||||||
|
if not text or not text.strip():
|
||||||
|
logger.warning("空文本,已跳过")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 获取模型
|
||||||
|
mode = _get_mode()
|
||||||
|
# 获取客户端
|
||||||
|
client = _get_client()
|
||||||
|
# 创建集合
|
||||||
|
collection = client.get_or_create_collection(collection_name)
|
||||||
|
# 创建hash id
|
||||||
|
text_id = hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||||
|
existing = collection.get(ids=[text_id])
|
||||||
|
if existing and existing.get("ids"):
|
||||||
|
logger.info(f"此文本已保存过,跳过保存,id={text_id}")
|
||||||
|
return text_id
|
||||||
|
# 生成文本的embedding模型处理结果 ndarray,通过tolist转为列表
|
||||||
|
embedding = mode.encode([text])[0].tolist()
|
||||||
|
|
||||||
|
# 添加到向量数据库中
|
||||||
|
collection.add(
|
||||||
|
documents=[text],
|
||||||
|
embeddings=[embedding],
|
||||||
|
ids=[text_id],
|
||||||
|
metadatas=[{"source": "document"}],
|
||||||
|
)
|
||||||
|
return text_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存文本向量库失败{str(e)}")
|
||||||
|
raise
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
西游记作者吴承恩
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# 尝试导入文本分割器类
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
|
# 如果没有报错,说明安装成功
|
||||||
|
print("langchain-text-splitters 安装成功!")
|
||||||
Reference in New Issue
Block a user