Skip to content

Commit

Permalink
Merge pull request #560 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
VikParuchuri authored Feb 19, 2025
2 parents d0a0455 + a1649ef commit 27d2b9e
Show file tree
Hide file tree
Showing 18 changed files with 293 additions and 240 deletions.
5 changes: 5 additions & 0 deletions marker/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class LayoutBuilder(BaseBuilder):
str,
"Skip layout and force every page to be treated as a specific block type.",
] = None
disable_tqdm: Annotated[
bool,
"Disable tqdm progress bars.",
] = False

def __init__(self, layout_model: LayoutPredictor, config=None):
self.layout_model = layout_model
Expand Down Expand Up @@ -68,6 +72,7 @@ def forced_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:


def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
self.layout_model.disable_tqdm = self.disable_tqdm
layout_results = self.layout_model(
[p.get_image(highres=False) for p in pages],
batch_size=int(self.get_batch_size())
Expand Down
7 changes: 7 additions & 0 deletions marker/builders/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class LineBuilder(BaseBuilder):
"Whether to run texify on inline math spans."
] = False
ocr_remove_blocks: Tuple[BlockTypes, ...] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents, BlockTypes.Equation)
disable_tqdm: Annotated[
bool,
"Disable tqdm progress bars.",
] = False

def __init__(self, detection_model: DetectionPredictor, inline_detection_model: InlineDetectionPredictor, ocr_error_model: OCRErrorPredictor, config=None):
super().__init__(config)
Expand Down Expand Up @@ -126,12 +130,14 @@ def get_ocr_error_batch_size(self):
return 4

def get_detection_results(self, page_images: List[Image.Image], run_detection: List[bool], do_inline_math_detection: bool):
self.detection_model.disable_tqdm = self.disable_tqdm
page_detection_results = self.detection_model(
images=page_images,
batch_size=self.get_detection_batch_size()
)
inline_detection_results = [None] * len(page_detection_results)
if do_inline_math_detection:
self.inline_detection_model.disable_tqdm = self.disable_tqdm
inline_detection_results = self.inline_detection_model(
images=page_images,
text_boxes=[[b.bbox for b in det_result.bboxes] for det_result in page_detection_results],
Expand Down Expand Up @@ -257,6 +263,7 @@ def ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: Provid
page_text = '\n'.join(' '.join(s.text for s in line.spans) for line in provider_lines)
page_texts.append(page_text)

self.ocr_error_model.disable_tqdm = self.disable_tqdm
ocr_error_detection_results = self.ocr_error_model(
page_texts,
batch_size=int(self.get_ocr_error_batch_size())
Expand Down
5 changes: 5 additions & 0 deletions marker/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class OcrBuilder(BaseBuilder):
"A list of languages to use for OCR.",
"Default is None."
] = None
disable_tqdm: Annotated[
bool,
"Disable tqdm progress bars.",
] = False

def __init__(self, recognition_model: RecognitionPredictor, config=None):
super().__init__(config)
Expand Down Expand Up @@ -77,6 +81,7 @@ def ocr_extraction(self, document: Document, pages: List[PageGroup], provider: P
if sum(len(b) for b in line_boxes)==0:
return

self.recognition_model.disable_tqdm = self.disable_tqdm
recognition_results = self.recognition_model(
images=images,
bboxes=line_boxes,
Expand Down
8 changes: 5 additions & 3 deletions marker/converters/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def __init__(
):
super().__init__(config)

if config is None:
config = {}

for block_type, override_block_type in self.override_map.items():
register_block_class(block_type, override_block_type)

Expand Down Expand Up @@ -132,14 +135,13 @@ def __init__(
if self.use_llm:
self.layout_builder_class = LLMLayoutBuilder

@cache
def build_document(self, filepath: str):
provider_cls = provider_from_filepath(filepath)
layout_builder = self.resolve_dependencies(self.layout_builder_class)
line_builder = self.resolve_dependencies(LineBuilder)
ocr_builder = self.resolve_dependencies(OcrBuilder)
with provider_cls(filepath, self.config) as provider:
document = DocumentBuilder(self.config)(provider, layout_builder, line_builder, ocr_builder)
provider = provider_cls(filepath, self.config)
document = DocumentBuilder(self.config)(provider, layout_builder, line_builder, ocr_builder)
structure_builder_cls = self.resolve_dependencies(StructureBuilder)
structure_builder_cls(document)

Expand Down
6 changes: 3 additions & 3 deletions marker/converters/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ class TableConverter(PdfConverter):
)
converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)

@cache
def build_document(self, filepath: str):
provider_cls = provider_from_filepath(filepath)
layout_builder = self.resolve_dependencies(self.layout_builder_class)
line_builder = self.resolve_dependencies(LineBuilder)
ocr_builder = self.resolve_dependencies(OcrBuilder)
document_builder = DocumentBuilder(self.config)
document_builder.disable_ocr = True
with provider_cls(filepath, self.config) as provider:
document = document_builder(provider, layout_builder, line_builder, ocr_builder)

provider = provider_cls(filepath, self.config)
document = document_builder(provider, layout_builder, line_builder, ocr_builder)

for page in document.pages:
page.structure = [p for p in page.structure if p.block_type in self.converter_block_types]
Expand Down
1 change: 1 addition & 0 deletions marker/processors/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def get_batch_size(self):

def get_latex_batched(self, equation_data: List[dict]):
inference_images = [eq["image"] for eq in equation_data]
self.texify_model.disable_tqdm = self.disable_tqdm
model_output = self.texify_model(inference_images, batch_size=self.get_batch_size())
predictions = [output.text for output in model_output]

Expand Down
2 changes: 1 addition & 1 deletion marker/processors/llm/llm_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class LLMEquationProcessor(BaseLLMSimpleBlockProcessor):
str,
"The prompt to use for generating LaTeX from equations.",
"Default is a string containing the Gemini prompt."
] = """You're an expert mathematician who is good at writing LaTeX code and html for equations.
] = r"""You're an expert mathematician who is good at writing LaTeX code and html for equations.
You'll receive an image of a math block that may contain one or more equations. Your job is to write html that represents the content of the image, with the equations in LaTeX format, and fenced by delimiters.
Some guidelines:
Expand Down
2 changes: 1 addition & 1 deletion marker/processors/llm/llm_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class LLMTextProcessor(BaseLLMSimpleBlockProcessor):

block_types = (BlockTypes.Line,)
image_remove_blocks = (BlockTypes.Equation,)
text_math_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
text_math_rewriting_prompt = r"""You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a set of extracted lines corresponding to the text in the image.
Your task is to correct any errors in the extracted lines, including math, formatting, and other inaccuracies, and output the corrected lines in a JSON format.
The number of output lines MUST match the number of input lines. Stay as faithful to the original text as possible.
Expand Down
15 changes: 11 additions & 4 deletions marker/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,18 @@ class TableProcessor(BaseProcessor):
List[BlockTypes],
"Block types to remove if they're contained inside the tables."
] = (BlockTypes.Text, BlockTypes.TextInlineMath)
pdftext_workers: Annotated[
int,
"The number of workers to use for pdftext.",
] = 4
row_split_threshold: Annotated[
float,
"The percentage of rows that need to be split across the table before row splitting is active.",
] = 0.5
pdftext_workers: Annotated[
int,
"The number of workers to use for pdftext.",
] = 1
disable_tqdm: Annotated[
bool,
"Whether to disable the tqdm progress bar.",
] = False

def __init__(
self,
Expand Down Expand Up @@ -95,6 +99,7 @@ def __call__(self, document: Document):
self.assign_ocr_lines(ocr_blocks) # Handle tables where OCR is needed
assert all("table_text_lines" in t for t in table_data), "All table data must have table cells"

self.table_rec_model.disable_tqdm = self.disable_tqdm
tables: List[TableResult] = self.table_rec_model(
[t["table_image"] for t in table_data],
batch_size=self.get_table_rec_batch_size()
Expand Down Expand Up @@ -372,6 +377,8 @@ def assign_pdftext_lines(self, extract_blocks: list, filepath: str):

def assign_ocr_lines(self, ocr_blocks: list):
det_images = [t["table_image"] for t in ocr_blocks]
self.recognition_model.disable_tqdm = self.disable_tqdm
self.detection_model.disable_tqdm = self.disable_tqdm
ocr_results: List[OCRResult] = self.recognition_model(
det_images,
[None] * len(det_images),
Expand Down
3 changes: 0 additions & 3 deletions marker/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,3 @@ def get_page_refs(self, idx: int) -> List[Reference]:

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
raise NotImplementedError
3 changes: 0 additions & 3 deletions marker/providers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ def __init__(self, filepath: str, config=None):
def __len__(self):
return self.image_count

def __exit__(self, exc_type, exc_value, traceback):
pass

def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
return [self.images[i] for i in idxs]

Expand Down
61 changes: 33 additions & 28 deletions marker/providers/pdf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import atexit
import contextlib
import ctypes
import re
from typing import Annotated, Dict, List, Optional, Set
Expand All @@ -9,7 +9,7 @@
from pdftext.extraction import dictionary_output
from pdftext.schema import Reference
from PIL import Image
from pypdfium2 import PdfiumError
from pypdfium2 import PdfiumError, PdfDocument

from marker.providers import BaseProvider, ProviderOutput, Char, ProviderPageLines
from marker.providers.utils import alphanum_ratio
Expand All @@ -33,7 +33,7 @@ class PdfProvider(BaseProvider):
pdftext_workers: Annotated[
int,
"The number of workers to use for pdftext.",
] = 4
] = 1
flatten_pdf: Annotated[
bool,
"Whether to flatten the PDF structure.",
Expand Down Expand Up @@ -74,33 +74,37 @@ class PdfProvider(BaseProvider):
def __init__(self, filepath: str, config=None):
super().__init__(filepath, config)

self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath)
self.page_lines: ProviderPageLines = {i: [] for i in range(len(self.doc))}
self.page_refs: Dict[int, List[Reference]] = {i: [] for i in range(len(self.doc))}
self.filepath = filepath

if self.page_range is None:
self.page_range = range(len(self.doc))
with self.get_doc() as doc:
self.page_count = len(doc)
self.page_lines: ProviderPageLines = {i: [] for i in range(len(doc))}
self.page_refs: Dict[int, List[Reference]] = {i: [] for i in range(len(doc))}

assert max(self.page_range) < len(self.doc) and min(self.page_range) >= 0, \
f"Invalid page range, values must be between 0 and {len(self.doc) - 1}. Min of provided page range is {min(self.page_range)} and max is {max(self.page_range)}."
if self.page_range is None:
self.page_range = range(len(doc))

if self.force_ocr:
# Manually assign page bboxes, since we can't get them from pdftext
self.page_bboxes = {i: self.doc[i].get_bbox() for i in self.page_range}
else:
self.page_lines = self.pdftext_extraction()
assert max(self.page_range) < len(doc) and min(self.page_range) >= 0, \
f"Invalid page range, values must be between 0 and {len(doc) - 1}. Min of provided page range is {min(self.page_range)} and max is {max(self.page_range)}."

atexit.register(self.cleanup_pdf_doc)
if self.force_ocr:
# Manually assign page bboxes, since we can't get them from pdftext
self.page_bboxes = {i: doc[i].get_bbox() for i in self.page_range}
else:
self.page_lines = self.pdftext_extraction(doc)

def __exit__(self, exc_type, exc_value, traceback):
self.cleanup_pdf_doc()
@contextlib.contextmanager
def get_doc(self):
doc = None
try:
doc = pdfium.PdfDocument(self.filepath)
yield doc
finally:
if doc:
doc.close()

def __len__(self) -> int:
return len(self.doc)

def cleanup_pdf_doc(self):
if self.doc is not None:
self.doc.close()
return self.page_count

def font_flags_to_format(self, flags: Optional[int]) -> Set[str]:
if flags is None:
Expand Down Expand Up @@ -166,7 +170,7 @@ def normalize_spaces(text):
text = text.replace(space, ' ')
return text

def pdftext_extraction(self) -> ProviderPageLines:
def pdftext_extraction(self, doc: PdfDocument) -> ProviderPageLines:
page_lines: ProviderPageLines = {}
page_char_blocks = dictionary_output(
self.filepath,
Expand All @@ -185,7 +189,7 @@ def pdftext_extraction(self) -> ProviderPageLines:
for page in page_char_blocks:
page_id = page["page"]
lines: List[ProviderOutput] = []
if not self.check_page(page_id):
if not self.check_page(page_id, doc):
continue

for block in page["blocks"]:
Expand Down Expand Up @@ -247,8 +251,8 @@ def check_line_spans(self, page_lines: List[ProviderOutput]) -> bool:
return False
return True

def check_page(self, page_id: int) -> bool:
page = self.doc.get_page(page_id)
def check_page(self, page_id: int, doc: PdfDocument) -> bool:
page = doc.get_page(page_id)
page_bbox = PolygonBox.from_bbox(page.get_bbox())
try:
page_objs = list(page.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_TEXT, pdfium_c.FPDF_PAGEOBJ_IMAGE]))
Expand Down Expand Up @@ -322,7 +326,8 @@ def _render_image(pdf: pdfium.PdfDocument, idx: int, dpi: int) -> Image.Image:
return image

def get_images(self, idxs: List[int], dpi: int) -> List[Image.Image]:
images = [self._render_image(self.doc, idx, dpi) for idx in idxs]
with self.get_doc() as doc:
images = [self._render_image(doc, idx, dpi) for idx in idxs]
return images

def get_page_bbox(self, idx: int) -> PolygonBox | None:
Expand Down
Loading

0 comments on commit 27d2b9e

Please sign in to comment.