Files
WeKnora/docreader/splitter/splitter.py
begoniezhao 2d66abedf0 feat: 新增文档模型类,调整配置与解析逻辑,优化日志及导入
移除日志设置与冗余代码,优化导入、类型提示及OCR后端管理
统一调整各文件模块导入路径为绝对导入
调整导入路径,移除部分导入,优化日志及注释
升级文档解析器为 Docx2Parser,优化超时与图片处理逻辑
2025-11-18 22:37:01 +08:00

314 lines
10 KiB
Python

"""Token splitter."""
import itertools
import logging
import re
from typing import Callable, Generic, List, Pattern, Tuple, TypeVar
from pydantic import BaseModel, Field, PrivateAttr
from docreader.splitter.header_hook import (
HeaderTracker,
)
from docreader.utils.split import split_by_char, split_by_sep
DEFAULT_CHUNK_OVERLAP = 100
DEFAULT_CHUNK_SIZE = 512
T = TypeVar("T")
logger = logging.getLogger(__name__)
class TextSplitter(BaseModel, Generic[T]):
chunk_size: int = Field(description="The token chunk size for each chunk.")
chunk_overlap: int = Field(
description="The token overlap of each chunk when splitting."
)
separators: List[str] = Field(
description="Default separators for splitting into words"
)
# Try to keep the matched characters as a whole.
# If it's too long, the content will be further segmented.
protected_regex: List[str] = Field(
description="Protected regex for splitting into words"
)
len_function: Callable[[str], int] = Field(description="The length function.")
# Header tracking Hook related attributes
header_hook: HeaderTracker = Field(default_factory=HeaderTracker, exclude=True)
_protected_fns: List[Pattern] = PrivateAttr()
_split_fns: List[Callable] = PrivateAttr()
def __init__(
self,
chunk_size: int = DEFAULT_CHUNK_SIZE,
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
separators: List[str] = ["\n", "", " "],
protected_regex: List[str] = [
# math formula
r"\$\$[\s\S]*?\$\$",
# image
r"!\[.*?\]\(.*?\)",
# link
r"\[.*?\]\(.*?\)",
# table header
r"(?:\|[^|\n]*)+\|[\r\n]+\s*(?:\|\s*:?-{3,}:?\s*)+\|[\r\n]+",
# table body
r"(?:\|[^|\n]*)+\|[\r\n]+",
# code header
r"```(?:\w+)[\r\n]+[^\r\n]*",
],
length_function: Callable[[str], int] = lambda x: len(x),
):
"""Initialize with parameters."""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
super().__init__(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
separators=separators,
protected_regex=protected_regex,
len_function=length_function,
)
self._protected_fns = [re.compile(reg) for reg in protected_regex]
self._split_fns = [split_by_sep(sep) for sep in separators] + [split_by_char()]
def split_text(self, text: str) -> List[Tuple[int, int, str]]:
"""Split text into chunks."""
if text == "":
return []
splits = self._split(text)
protect = self._split_protected(text)
splits = self._join(splits, protect)
assert "".join(splits) == text
chunks = self._merge(splits)
return chunks
def _split(self, text: str) -> List[str]:
"""Break text into splits that are smaller than chunk size.
NOTE: the splits contain the separators.
"""
if self.len_function(text) <= self.chunk_size:
return [text]
splits = []
for split_fn in self._split_fns:
splits = split_fn(text)
if len(splits) > 1:
break
new_splits = []
for split in splits:
split_len = self.len_function(split)
if split_len <= self.chunk_size:
new_splits.append(split)
else:
# recursively split
new_splits.extend(self._split(split))
return new_splits
def _merge(self, splits: List[str]) -> List[Tuple[int, int, str]]:
"""Merge splits into chunks.
The high-level idea is to keep adding splits to a chunk until we
exceed the chunk size, then we start a new chunk with overlap.
When we start a new chunk, we pop off the first element of the previous
chunk until the total length is less than the chunk size.
"""
chunks: List[Tuple[int, int, str]] = []
cur_chunk: List[Tuple[int, int, str]] = []
cur_headers, cur_len = "", 0
cur_start, cur_end = 0, 0
for split in splits:
cur_end = cur_start + len(split)
split_len = self.len_function(split)
if split_len > self.chunk_size:
logger.error(
f"Got a split of size {split_len}, ",
f"larger than chunk size {self.chunk_size}.",
)
self.header_hook.update(split)
cur_headers = self.header_hook.get_headers()
cur_headers_len = self.len_function(cur_headers)
if cur_headers_len > self.chunk_size:
logger.error(
f"Got headers of size {cur_headers_len}, ",
f"larger than chunk size {self.chunk_size}.",
)
cur_headers, cur_headers_len = "", 0
# if we exceed the chunk size after adding the new split, then
# we need to end the current chunk and start a new one
if cur_len + split_len + cur_headers_len > self.chunk_size:
# end the previous chunk
if len(cur_chunk) > 0:
chunks.append(
(
cur_chunk[0][0],
cur_chunk[-1][1],
"".join([c[2] for c in cur_chunk]),
)
)
# start a new chunk with overlap
# keep popping off the first element of the previous chunk until:
# 1. the current chunk length is less than chunk overlap
# 2. the total length is less than chunk size
while cur_chunk and (
cur_len > self.chunk_overlap
or cur_len + split_len + cur_headers_len > self.chunk_size
):
# pop off the first element
first_chunk = cur_chunk.pop(0)
cur_len -= self.len_function(first_chunk[2])
if (
cur_headers
and split_len + cur_headers_len < self.chunk_size
and cur_headers not in split
):
cur_chunk.insert(
0,
(
cur_chunk[0][0] if cur_chunk else cur_start,
cur_chunk[0][1] if cur_chunk else cur_end,
cur_headers,
),
)
cur_len += cur_headers_len
cur_chunk.append((cur_start, cur_end, split))
cur_len += split_len
cur_start = cur_end
# handle the last chunk
assert cur_chunk
if cur_headers and cur_len < self.chunk_size:
cur_chunk.insert(0, (cur_chunk[0][0], cur_chunk[0][1], cur_headers))
chunks.append(
(
cur_chunk[0][0],
cur_chunk[-1][1],
"".join([c[2] for c in cur_chunk]),
)
)
return chunks
def _split_protected(self, text: str) -> List[Tuple[int, str]]:
matches = [
(match.start(), match.end())
for pattern in self._protected_fns
for match in pattern.finditer(text)
]
matches.sort(key=lambda x: (x[0], -x[1]))
res = []
def fold(initial: int, current: Tuple[int, int]) -> int:
if current[0] >= initial:
if current[1] - current[0] < self.chunk_size:
res.append((current[0], text[current[0] : current[1]]))
else:
logger.warning(f"Protected text ignore: {current}")
return max(initial, current[1])
# filter overlapping matches
list(itertools.accumulate(matches, fold, initial=-1))
return res
def _join(self, splits: List[str], protect: List[Tuple[int, str]]) -> List[str]:
"""
Merges and splits elements in splits array based on protected substrings.
The function processes the input splits to ensure all protected substrings
remain as single items. If a protected substring is concatenated with preceding
or following content in any split element, it will be separated from
the adjacent content. The final result maintains the original order of content
while enforcing the integrity of protected substrings.
Key behaviors:
1. Preserves the complete structure of each protected substring
2. Separates protected substrings from any adjacent non-protected content
3. Maintains the original sequence of all content except for necessary
4. Handles cases where protected substrings are partially concatenated
"""
j = 0
point, start = 0, 0
res = []
for split in splits:
end = start + len(split)
cur = split[point - start :]
while j < len(protect):
p_start, p_content = protect[j]
p_end = p_start + len(p_content)
if end <= p_start:
break
if point < p_start:
local_end = p_start - point
res.append(cur[:local_end])
cur = cur[local_end:]
point = p_start
res.append(p_content)
j += 1
if point < p_end:
local_start = p_end - point
cur = cur[local_start:]
point = p_end
if not cur:
break
if cur:
res.append(cur)
point = end
start = end
return res
if __name__ == "__main__":
s = """
这是一些普通文本。
| 姓名 | 年龄 | 城市 |
|------|------|------|
| 张三 | 25 | 北京 |
| 李四 | 30 | 上海 |
| 王五 | 28 | 广州 |
| 张三 | 25 | 北京 |
| 李四 | 30 | 上海 |
| 王五 | 28 | 广州 |
这是文本结束。
"""
sp = TextSplitter(chunk_size=200, chunk_overlap=2)
ck = sp.split_text(s)
for c in ck:
print("------", len(c))
print(c)
pass