Files
03Rag/14语义感知分块策略.py
T
heyong.fu a17c65c4bc feat: rag
2026-05-06 11:35:10 +08:00

91 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 正则表达式
import re
from sentence_transformers import SentenceTransformer
import numpy as np
# Embedding model, loaded once at module import time; its encode() method is
# used to turn text windows into vectors for similarity comparison.
model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
class Semantic_splitter:
    """Split text into chunks whose neighbouring pieces are semantically related.

    The text is cut into sentences, the sentences are grouped into
    non-overlapping windows of ``window_size`` sentences, every window is
    embedded, and a chunk boundary is placed between two consecutive windows
    whenever their cosine similarity falls below ``threshold``.
    """

    # Sentence terminators: ideographic full stop, full-width "!" (U+FF01),
    # full-width "?" (U+FF1F), ASCII "!", "?", "." and newline. The full-width
    # marks are written as escapes so they cannot be confused with (or silently
    # lost as) their ASCII look-alikes; an empty alternative here would match
    # at every position and split the text into single characters. The
    # capturing group makes re.split keep each delimiter so it can be
    # re-attached to its sentence.
    _SENTENCE_DELIMITERS = re.compile(r"([。\uff01\uff1f!?.\n])")

    def __init__(self, window_size, threshold, embedding_model=None):
        """Configure the splitter.

        Args:
            window_size: number of sentences per window; must be >= 1.
            threshold: cosine-similarity cut-off; a chunk boundary is inserted
                between two adjacent windows whose similarity is below it.
            embedding_model: object exposing ``encode(list_of_str)`` that
                returns one vector per string. When None, the module-level
                ``model`` is used at call time.

        Raises:
            ValueError: if ``window_size`` is smaller than 1 (such a window
                would never advance through the sentence list).
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size!r}")
        # Number of sentences grouped into each window.
        self.window_size = window_size
        # Similarity threshold between adjacent windows.
        self.threshold = threshold
        # Optional injected encoder; falls back to the module-level model.
        self.embedding_model = embedding_model

    def create_documents(self, text):
        """Split *text* into semantically coherent chunks.

        Args:
            text: the input string.

        Returns:
            A list of non-empty chunk strings in their original order; an
            empty list when *text* contains no sentences.
        """
        sents = self._split_sentences(text)
        print(sents)
        docs = self._group_windows(sents)
        # e.g. 5 sentences with window_size=2 -> 3 windows: [1,2], [3,4], [5]
        print(f"{len(docs)}个块")
        print(docs)
        if not docs:
            # Nothing to embed; skip the model call entirely.
            return []
        encoder = self.embedding_model if self.embedding_model is not None else model
        embeddings = encoder.encode(docs)
        print(embeddings)
        split_points = self._find_split_points(embeddings)
        print(split_points)
        return self._merge(docs, split_points)

    @classmethod
    def _split_sentences(cls, text):
        """Return the non-empty sentences of *text*, each with its terminator."""
        parts = cls._SENTENCE_DELIMITERS.split(text)
        # With a capturing group, parts alternates
        # [body, delimiter, body, delimiter, ..., trailing_body]; the trailing
        # body has no delimiter, so pad with "" to keep it instead of dropping it.
        bodies = parts[0::2]
        delimiters = parts[1::2] + [""]
        sents = []
        for body, delim in zip(bodies, delimiters):
            # A newline delimiter strips to "", so line breaks are not kept.
            sentence = body.strip() + delim.strip()
            if sentence:
                sents.append(sentence)
        return sents

    def _group_windows(self, sents):
        """Concatenate consecutive, non-overlapping runs of window_size sentences."""
        size = self.window_size
        return ["".join(sents[start:start + size]) for start in range(0, len(sents), size)]

    def _find_split_points(self, embeddings):
        """Return the window indices that start a new chunk (always begins with 0)."""
        split_points = [0]
        for i in range(1, len(embeddings)):
            prev_vec, cur_vec = embeddings[i - 1], embeddings[i]
            # Cosine similarity: dot product over the product of the norms.
            sim = np.dot(prev_vec, cur_vec) / (
                np.linalg.norm(prev_vec) * np.linalg.norm(cur_vec)
            )
            if sim < self.threshold:
                print(f"相似度低于阈值{self.threshold},在位置{i}添加分割点")
                split_points.append(i)
        return split_points

    @staticmethod
    def _merge(docs, split_points):
        """Join the windows between consecutive split points into final chunks."""
        bounds = split_points + [len(docs)]
        result = []
        for start, end in zip(bounds, bounds[1:]):
            chunk = "".join(docs[start:end])
            if chunk.strip():
                result.append(chunk)
        return result
long_text = """今天天气晴朗,适合去公园散步。
量子力学中的叠加态是描述粒子同时处于多个状态的数学工具。
Windows命令行中复制文件可以使用copy命令。
大熊猫主要以竹子为食,是中国的国宝。
欧拉公式被誉为“最美的数学公式”。"""
# Demo: two sentences per window; adjacent windows whose cosine similarity
# is below 0.85 begin a new chunk.
semantic_spitter = Semantic_splitter(threshold=0.85, window_size=2)
documents = semantic_spitter.create_documents(long_text)
print(f"总共分割为{len(documents)}个块")
for position, chunk_text in enumerate(documents, start=1):
    # Header line followed by the chunk text on the next line.
    print(f"---{position}个块", chunk_text, sep="\n")