mirror of
https://github.com/Tencent/WeKnora.git
synced 2025-11-25 03:15:00 +08:00
312 lines
10 KiB
Python
312 lines
10 KiB
Python
"""Token splitter."""
|
|
|
|
import itertools
|
|
import logging
|
|
import re
|
|
from typing import Callable, Generic, List, Pattern, Tuple, TypeVar
|
|
|
|
from pydantic import BaseModel, Field, PrivateAttr
|
|
|
|
from docreader.splitter.header_hook import (
|
|
HeaderTracker,
|
|
)
|
|
from docreader.utils.split import split_by_char, split_by_sep
|
|
|
|
DEFAULT_CHUNK_OVERLAP = 100
|
|
DEFAULT_CHUNK_SIZE = 512
|
|
|
|
T = TypeVar("T")
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TextSplitter(BaseModel, Generic[T]):
|
|
chunk_size: int = Field(description="The token chunk size for each chunk.")
|
|
chunk_overlap: int = Field(
|
|
description="The token overlap of each chunk when splitting."
|
|
)
|
|
separators: List[str] = Field(
|
|
description="Default separators for splitting into words"
|
|
)
|
|
|
|
# Try to keep the matched characters as a whole.
|
|
# If it's too long, the content will be further segmented.
|
|
protected_regex: List[str] = Field(
|
|
description="Protected regex for splitting into words"
|
|
)
|
|
len_function: Callable[[str], int] = Field(description="The length function.")
|
|
# Header tracking Hook related attributes
|
|
header_hook: HeaderTracker = Field(default_factory=HeaderTracker, exclude=True)
|
|
|
|
_protected_fns: List[Pattern] = PrivateAttr()
|
|
_split_fns: List[Callable] = PrivateAttr()
|
|
|
|
def __init__(
|
|
self,
|
|
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
|
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
|
separators: List[str] = ["\n", "。", " "],
|
|
protected_regex: List[str] = [
|
|
# math formula
|
|
r"\$\$[\s\S]*?\$\$",
|
|
# image
|
|
r"!\[.*?\]\(.*?\)",
|
|
# link
|
|
r"\[.*?\]\(.*?\)",
|
|
# table header
|
|
r"(?:\|[^|\n]*)+\|[\r\n]+\s*(?:\|\s*:?-{3,}:?\s*)+\|[\r\n]+",
|
|
# table body
|
|
r"(?:\|[^|\n]*)+\|[\r\n]+",
|
|
# code header
|
|
r"```(?:\w+)[\r\n]+[^\r\n]*",
|
|
],
|
|
length_function: Callable[[str], int] = lambda x: len(x),
|
|
):
|
|
"""Initialize with parameters."""
|
|
if chunk_overlap > chunk_size:
|
|
raise ValueError(
|
|
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
|
f"({chunk_size}), should be smaller."
|
|
)
|
|
|
|
super().__init__(
|
|
chunk_size=chunk_size,
|
|
chunk_overlap=chunk_overlap,
|
|
separators=separators,
|
|
protected_regex=protected_regex,
|
|
len_function=length_function,
|
|
)
|
|
self._protected_fns = [re.compile(reg) for reg in protected_regex]
|
|
self._split_fns = [split_by_sep(sep) for sep in separators] + [split_by_char()]
|
|
|
|
def split_text(self, text: str) -> List[Tuple[int, int, str]]:
|
|
"""Split text into chunks."""
|
|
if text == "":
|
|
return []
|
|
|
|
splits = self._split(text)
|
|
protect = self._split_protected(text)
|
|
splits = self._join(splits, protect)
|
|
|
|
assert "".join(splits) == text
|
|
|
|
chunks = self._merge(splits)
|
|
return chunks
|
|
|
|
def _split(self, text: str) -> List[str]:
|
|
"""Break text into splits that are smaller than chunk size.
|
|
|
|
NOTE: the splits contain the separators.
|
|
"""
|
|
if self.len_function(text) <= self.chunk_size:
|
|
return [text]
|
|
|
|
splits = []
|
|
for split_fn in self._split_fns:
|
|
splits = split_fn(text)
|
|
if len(splits) > 1:
|
|
break
|
|
|
|
new_splits = []
|
|
for split in splits:
|
|
split_len = self.len_function(split)
|
|
if split_len <= self.chunk_size:
|
|
new_splits.append(split)
|
|
else:
|
|
# recursively split
|
|
new_splits.extend(self._split(split))
|
|
return new_splits
|
|
|
|
def _merge(self, splits: List[str]) -> List[Tuple[int, int, str]]:
|
|
"""Merge splits into chunks.
|
|
|
|
The high-level idea is to keep adding splits to a chunk until we
|
|
exceed the chunk size, then we start a new chunk with overlap.
|
|
|
|
When we start a new chunk, we pop off the first element of the previous
|
|
chunk until the total length is less than the chunk size.
|
|
"""
|
|
chunks: List[Tuple[int, int, str]] = []
|
|
|
|
cur_chunk: List[Tuple[int, int, str]] = []
|
|
|
|
cur_headers, cur_len = "", 0
|
|
cur_start, cur_end = 0, 0
|
|
for split in splits:
|
|
cur_end = cur_start + len(split)
|
|
split_len = self.len_function(split)
|
|
if split_len > self.chunk_size:
|
|
logger.error(
|
|
f"Got a split of size {split_len}, ",
|
|
f"larger than chunk size {self.chunk_size}.",
|
|
)
|
|
|
|
self.header_hook.update(split)
|
|
cur_headers = self.header_hook.get_headers()
|
|
cur_headers_len = self.len_function(cur_headers)
|
|
|
|
if cur_headers_len > self.chunk_size:
|
|
logger.error(
|
|
f"Got headers of size {cur_headers_len}, ",
|
|
f"larger than chunk size {self.chunk_size}.",
|
|
)
|
|
cur_headers, cur_headers_len = "", 0
|
|
|
|
# if we exceed the chunk size after adding the new split, then
|
|
# we need to end the current chunk and start a new one
|
|
if cur_len + split_len + cur_headers_len > self.chunk_size:
|
|
# end the previous chunk
|
|
if len(cur_chunk) > 0:
|
|
chunks.append(
|
|
(
|
|
cur_chunk[0][0],
|
|
cur_chunk[-1][1],
|
|
"".join([c[2] for c in cur_chunk]),
|
|
)
|
|
)
|
|
|
|
# start a new chunk with overlap
|
|
# keep popping off the first element of the previous chunk until:
|
|
# 1. the current chunk length is less than chunk overlap
|
|
# 2. the total length is less than chunk size
|
|
while cur_chunk and (
|
|
cur_len > self.chunk_overlap
|
|
or cur_len + split_len + cur_headers_len > self.chunk_size
|
|
):
|
|
# pop off the first element
|
|
first_chunk = cur_chunk.pop(0)
|
|
cur_len -= self.len_function(first_chunk[2])
|
|
|
|
if (
|
|
cur_headers
|
|
and split_len + cur_headers_len < self.chunk_size
|
|
and cur_headers not in split
|
|
):
|
|
cur_chunk.insert(
|
|
0,
|
|
(
|
|
cur_chunk[0][0] if cur_chunk else cur_start,
|
|
cur_chunk[0][1] if cur_chunk else cur_end,
|
|
cur_headers,
|
|
),
|
|
)
|
|
cur_len += cur_headers_len
|
|
|
|
cur_chunk.append((cur_start, cur_end, split))
|
|
cur_len += split_len
|
|
cur_start = cur_end
|
|
|
|
# handle the last chunk
|
|
assert cur_chunk
|
|
chunks.append(
|
|
(
|
|
cur_chunk[0][0],
|
|
cur_chunk[-1][1],
|
|
"".join([c[2] for c in cur_chunk]),
|
|
)
|
|
)
|
|
|
|
return chunks
|
|
|
|
def _split_protected(self, text: str) -> List[Tuple[int, str]]:
|
|
matches = [
|
|
(match.start(), match.end())
|
|
for pattern in self._protected_fns
|
|
for match in pattern.finditer(text)
|
|
]
|
|
matches.sort(key=lambda x: (x[0], -x[1]))
|
|
|
|
res = []
|
|
|
|
def fold(initial: int, current: Tuple[int, int]) -> int:
|
|
if current[0] >= initial:
|
|
if current[1] - current[0] < self.chunk_size:
|
|
res.append((current[0], text[current[0] : current[1]]))
|
|
else:
|
|
logger.warning(f"Protected text ignore: {current}")
|
|
return max(initial, current[1])
|
|
|
|
# filter overlapping matches
|
|
list(itertools.accumulate(matches, fold, initial=-1))
|
|
return res
|
|
|
|
def _join(self, splits: List[str], protect: List[Tuple[int, str]]) -> List[str]:
|
|
"""
|
|
Merges and splits elements in splits array based on protected substrings.
|
|
|
|
The function processes the input splits to ensure all protected substrings
|
|
remain as single items. If a protected substring is concatenated with preceding
|
|
or following content in any split element, it will be separated from
|
|
the adjacent content. The final result maintains the original order of content
|
|
while enforcing the integrity of protected substrings.
|
|
|
|
Key behaviors:
|
|
1. Preserves the complete structure of each protected substring
|
|
2. Separates protected substrings from any adjacent non-protected content
|
|
3. Maintains the original sequence of all content except for necessary
|
|
4. Handles cases where protected substrings are partially concatenated
|
|
"""
|
|
j = 0
|
|
point, start = 0, 0
|
|
res = []
|
|
|
|
for split in splits:
|
|
end = start + len(split)
|
|
|
|
cur = split[point - start :]
|
|
while j < len(protect):
|
|
p_start, p_content = protect[j]
|
|
p_end = p_start + len(p_content)
|
|
|
|
if end <= p_start:
|
|
break
|
|
|
|
if point < p_start:
|
|
local_end = p_start - point
|
|
res.append(cur[:local_end])
|
|
cur = cur[local_end:]
|
|
point = p_start
|
|
|
|
res.append(p_content)
|
|
j += 1
|
|
|
|
if point < p_end:
|
|
local_start = p_end - point
|
|
cur = cur[local_start:]
|
|
point = p_end
|
|
|
|
if not cur:
|
|
break
|
|
|
|
if cur:
|
|
res.append(cur)
|
|
point = end
|
|
|
|
start = end
|
|
return res
|
|
|
|
|
|
if __name__ == "__main__":
|
|
s = """
|
|
这是一些普通文本。
|
|
|
|
| 姓名 | 年龄 | 城市 |
|
|
|------|------|------|
|
|
| 张三 | 25 | 北京 |
|
|
| 李四 | 30 | 上海 |
|
|
| 王五 | 28 | 广州 |
|
|
| 张三 | 25 | 北京 |
|
|
| 李四 | 30 | 上海 |
|
|
| 王五 | 28 | 广州 |
|
|
|
|
这是文本结束。
|
|
|
|
"""
|
|
|
|
sp = TextSplitter(chunk_size=200, chunk_overlap=2)
|
|
ck = sp.split_text(s)
|
|
for c in ck:
|
|
print("------", len(c))
|
|
print(c)
|
|
pass
|