# Mirror of https://github.com/Tencent/WeKnora.git
# Synced 2025-11-25 19:37:45 +08:00 (229 lines, 7.3 KiB, Python)
import base64
|
||
import logging
|
||
import os
|
||
import re
|
||
import uuid
|
||
from typing import Dict, List, Match, Optional, Tuple
|
||
|
||
from docreader.models.document import Document
|
||
from docreader.parser.base_parser import BaseParser
|
||
from docreader.parser.chain_parser import PipelineParser
|
||
from docreader.utils import endecode
|
||
|
||
# Get logger object
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class MarkdownTableUtil:
    """Normalize the whitespace of pipe-delimited Markdown tables.

    Every row that starts and ends with ``|`` is rewritten as
    ``| cell | cell |`` (cells stripped, one space on each side of every pipe),
    keeping the row's leading indentation. Alignment rows are additionally
    collapsed to the canonical ``---`` / ``:---`` / ``---:`` / ``:---:`` form.
    """

    def __init__(self):
        # Alignment (delimiter) row: cells made only of ':' and '-', e.g.
        # "| :---- | ---: |". Group 1 captures the leading indentation.
        self.align_pattern = re.compile(
            r"^([\t ]*)\|[\t ]*[:-]+(?:[\t ]*\|[\t ]*[:-]+)*[\t ]*\|[\t ]*$",
            re.MULTILINE,
        )
        # Any single line that starts and ends with a pipe (at least two
        # pipes). Group 1 captures the leading indentation.
        self.line_pattern = re.compile(
            r"^([\t ]*)\|[\t ]*[^|\r\n]*(?:[\t ]*\|[^|\r\n]*)*\|[\t ]*$",
            re.MULTILINE,
        )

    def format_table(self, content: str) -> str:
        """Return ``content`` with every Markdown table row re-spaced.

        Args:
            content: Arbitrary Markdown text; non-table lines are untouched.

        Returns:
            The text with table rows and alignment rows normalized.
        """

        def process_align(match: Match[str]) -> str:
            # Alignment cells are never empty (the regex requires [:-]+), so
            # filtering blanks only discards the text outside the outer pipes.
            columns = [col.strip() for col in match.group(0).split("|") if col.strip()]

            processed = []
            for col in columns:
                left_colon = ":" if col.startswith(":") else ""
                right_colon = ":" if col.endswith(":") else ""
                processed.append(left_colon + "---" + right_colon)

            prefix = match.group(1)
            return prefix + "| " + " | ".join(processed) + " |"

        def process_line(match: Match[str]) -> str:
            prefix = match.group(1)
            # Remove the indentation and trailing blanks, then the outer pipes.
            # Empty cells are kept (not filtered out) so each remaining cell
            # stays under its own header column.
            inner = match.group(0)[len(prefix):].strip()[1:-1]
            columns = [col.strip() for col in inner.split("|")]
            return prefix + "| " + " | ".join(columns) + " |"

        formatted_content = content
        # Re-space all rows first; the alignment pass then canonicalizes the
        # already re-spaced delimiter rows.
        formatted_content = self.line_pattern.sub(process_line, formatted_content)
        formatted_content = self.align_pattern.sub(process_align, formatted_content)

        return formatted_content

    @staticmethod
    def _self_test():
        """Print the formatted version of a small sample document."""
        # The rows under the "表格3" heading are indented by 4 spaces, as the
        # heading describes, to exercise indentation preservation.
        test_content = """
# 测试表格
普通文本---不会被匹配

## 表格1(无前置空格)

| 姓名 | 年龄 | 城市 |
| :---------- | -------: | :------ |
| 张三 | 25 | 北京 |

## 表格3(前置4个空格+首尾|)
    | 产品 | 价格 | 库存 |
    | :-------------: | ----------- | :-----------: |
    | 手机 | 5999 | 100 |
"""
        util = MarkdownTableUtil()
        format_content = util.format_table(test_content)
        print(format_content)
|
||
|
||
|
||
class MarkdownTableFormatter(BaseParser):
    """Parser stage that decodes raw bytes and tidies the Markdown tables in them."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Holds the pre-compiled table regexes; reused for every document.
        self.table_helper = MarkdownTableUtil()

    def parse_into_text(self, content: bytes) -> Document:
        """Decode ``content`` and return a Document whose tables are re-spaced."""
        decoded_text = endecode.decode_bytes(content)
        return Document(content=self.table_helper.format_table(decoded_text))
|
||
|
||
|
||
class MarkdownImageUtil:
    """Find, extract and rewrite image references in Markdown text."""

    def __init__(self):
        # Inline data-URI image: ![title](data:image/<ext>[+<sub>];base64,<data>).
        # Group 1 = title, group 2 = extension ("png", or "svg" for
        # "svg+xml"), group 3 = base64 payload.
        self.b64_pattern = re.compile(
            r"!\[([^\]]*)\]\(data:image/(\w+)\+?\w*;base64,([^\)]+)\)"
        )
        # Any Markdown image: group 1 = title, group 2 = target path/URL.
        self.image_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")
        self.replace_pattern = re.compile(r"!\[([^\]]*)\]\(([^)]+)\)")

    def extract_image(
        self,
        content: str,
        path_prefix: Optional[str] = None,
        replace: bool = True,
    ) -> Tuple[str, List[str]]:
        """Collect the targets of all Markdown image references in ``content``.

        Args:
            content: Markdown text to scan.
            path_prefix: If given, each collected target becomes
                ``f"{path_prefix}/{target}"``.
            replace: If True, rewrite every image reference in the returned
                text to ``![title](<collected path>)``; if False, leave the
                text unchanged.

        Returns:
            ``(text, image_paths)``: the (possibly rewritten) text and the
            collected paths in document order.
        """
        # Collected image paths, in the order they appear.
        images: List[str] = []

        def repl(match: Match[str]) -> str:
            title = match.group(1)
            image_path = match.group(2)
            if path_prefix:
                image_path = f"{path_prefix}/{image_path}"

            images.append(image_path)

            if not replace:
                return match.group(0)

            # Keep the image reference, pointing it at the (prefixed) path.
            return f"![{title}]({image_path})"

        text = self.image_pattern.sub(repl, content)
        logger.debug(f"Extracted {len(images)} images from markdown")
        return text, images

    def extract_base64(
        self,
        content: str,
        path_prefix: Optional[str] = None,
        replace: bool = True,
    ) -> Tuple[str, Dict[str, bytes]]:
        """Extract inline base64 (data-URI) images from Markdown content.

        Each payload is turned into bytes with ``endecode.encode_image`` and
        stored under a fresh ``<uuid4>.<ext>`` name, prefixed with
        ``path_prefix/`` when given. A payload that yields no bytes is logged,
        and its whole image reference is replaced by its title text.

        Args:
            content: Markdown text to scan.
            path_prefix: Optional directory-like prefix for generated names.
            replace: If True, rewrite each data URI to ``![title](<name>)``;
                if False, keep successfully decoded references unchanged.

        Returns:
            ``(text, images)`` where ``images`` maps generated name -> bytes.
        """
        # image_path => decoded image bytes
        images: Dict[str, bytes] = {}

        def repl(match: Match[str]) -> str:
            title = match.group(1)
            img_ext = match.group(2)
            img_b64 = match.group(3)

            image_byte = endecode.encode_image(img_b64, errors="ignore")
            if not image_byte:
                # Log only a prefix of the payload; it can be megabytes long.
                logger.error(
                    f"Failed to decode base64 image skip it: {img_b64[:64]}... "
                    f"({len(img_b64)} chars)"
                )
                return title

            image_path = f"{uuid.uuid4()}.{img_ext}"
            if path_prefix:
                image_path = f"{path_prefix}/{image_path}"
            images[image_path] = image_byte

            if not replace:
                return match.group(0)

            # Swap the inline data URI for the generated image path.
            return f"![{title}]({image_path})"

        text = self.b64_pattern.sub(repl, content)
        logger.debug(f"Extracted {len(images)} base64 images from markdown")
        return text, images

    def replace_path(self, content: str, images: Dict[str, str]) -> str:
        """Rewrite image targets in ``content`` using the ``images`` mapping.

        Args:
            content: Markdown text.
            images: Maps an image target, exactly as written in the text, to
                its replacement (e.g. an uploaded URL).

        Returns:
            The text with every mapped target replaced. References whose target
            is not in ``images`` are left untouched.
        """
        # Distinct targets actually replaced (used only for the debug log).
        content_replace: set = set()

        def repl(match: Match[str]) -> str:
            title = match.group(1)
            image_path = match.group(2)
            if image_path not in images:
                return match.group(0)

            content_replace.add(image_path)
            return f"![{title}]({images[image_path]})"

        text = self.replace_pattern.sub(repl, content)
        logger.debug(f"Replaced {len(content_replace)} images in markdown")
        return text

    @staticmethod
    def _self_test():
        """Extract base64 images from a sample and write them to the CWD."""
        # NOTE(review): this sample contains no data-URI image, so nothing is
        # extracted or written; embed one to exercise extract_base64.
        your_content = "testtest"
        image_handle = MarkdownImageUtil()
        text, images = image_handle.extract_base64(your_content)
        print(text)

        for image_url, image_byte in images.items():
            with open(image_url, "wb") as f:
                f.write(image_byte)
|
||
|
||
|
||
class MarkdownImageBase64(BaseParser):
    """Parser stage that uploads inline base64 images and relinks them.

    Data-URI images are pulled out of the text, each is uploaded through
    ``self.storage.upload_bytes``, and the Markdown is rewritten to point at
    the returned URLs. The resulting Document carries a URL -> base64 map.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.image_helper = MarkdownImageUtil()

    def parse_into_text(self, content: bytes) -> Document:
        """Decode ``content``, upload its inline images, and relink them."""
        # Convert byte content to string using universal decoding method
        markdown = endecode.decode_bytes(content)
        markdown, inline_images = self.image_helper.extract_base64(
            markdown, path_prefix="images"
        )

        url_by_local_path: Dict[str, str] = {}
        encoded_by_url: Dict[str, str] = {}

        logger.debug(f"Uploading {len(inline_images)} images from markdown")
        for local_path, raw_bytes in inline_images.items():
            # Extension including the leading dot, lower-cased (e.g. ".png").
            suffix = os.path.splitext(local_path)[1].lower()
            uploaded_url = self.storage.upload_bytes(raw_bytes, suffix)

            url_by_local_path[local_path] = uploaded_url
            encoded_by_url[uploaded_url] = base64.b64encode(raw_bytes).decode()

        return Document(
            content=self.image_helper.replace_path(markdown, url_by_local_path),
            images=encoded_by_url,
        )
|
||
|
||
|
||
class MarkdownParser(PipelineParser):
    """Markdown parser composed of a table-formatting and an image-upload stage."""

    # Stage classes handed to PipelineParser, in this order. Presumably each
    # stage's output feeds the next; confirm against chain_parser.
    _parser_cls = (
        MarkdownTableFormatter,
        MarkdownImageBase64,
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Ad-hoc smoke run: parse a small Markdown sample, then run the self-tests.
    logging.basicConfig(level=logging.DEBUG)

    sample_markdown = "testtest"
    markdown_parser = MarkdownParser()
    parsed_doc = markdown_parser.parse_into_text(sample_markdown.encode())

    logger.info(parsed_doc.content)
    logger.info(f"Images: {len(parsed_doc.images)}, name: {parsed_doc.images.keys()}")

    MarkdownImageUtil._self_test()
    MarkdownTableUtil._self_test()
|