Add GPU memory cleanup
Browse files — mineru_single.py (+47, −28)
mineru_single.py
CHANGED
@@ -4,6 +4,8 @@ import uuid
|
|
4 |
import json
|
5 |
import requests
|
6 |
import logging
|
|
|
|
|
7 |
|
8 |
from magic_pdf.data.dataset import PymuDocDataset
|
9 |
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
@@ -48,40 +50,57 @@ class Processor:
|
|
48 |
logger.error("Failed to initialize Processor: %s", str(e))
|
49 |
raise
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
def process(self, file_url: str, key: str) -> str:
|
52 |
"""
|
53 |
Process a single PDF, returning final Markdown with irrelevant images removed.
|
54 |
"""
|
55 |
logger.info("Processing file: %s", file_url)
|
56 |
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
class ImageWriter(DataWriter):
|
87 |
"""
|
|
|
4 |
import json
|
5 |
import requests
|
6 |
import logging
|
7 |
+
import torch
|
8 |
+
import gc
|
9 |
|
10 |
from magic_pdf.data.dataset import PymuDocDataset
|
11 |
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
|
|
50 |
logger.error("Failed to initialize Processor: %s", str(e))
|
51 |
raise
|
52 |
|
53 |
+
def cleanup_gpu(self):
    """
    Release GPU memory held by PyTorch's CUDA caching allocator.

    Runs Python garbage collection first so unreferenced tensors are
    actually freed, then empties the CUDA cache so the VRAM is returned
    to the driver. This prevents VRAM accumulation across repeated
    ``process`` calls. Best-effort: never raises, failures are logged.
    """
    try:
        gc.collect()  # drop unreferenced Python objects (incl. tensors)
        # Only touch the CUDA allocator when a GPU is actually present;
        # avoids a misleading "cleaned up" log on CPU-only installs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear memory cache on GPU
        logger.info("GPU memory cleaned up.")
    except Exception as e:
        # Cleanup must never break the request path; log and continue.
        logger.error("Error during GPU cleanup: %s", e)
|
64 |
+
|
65 |
def process(self, file_url: str, key: str) -> str:
    """
    Process a single PDF, returning final Markdown with irrelevant images removed.

    Downloads the PDF at ``file_url``, runs OCR/layout analysis via
    ``doc_analyze``, classifies and strips irrelevant images through
    ``ImageWriter``, and returns the resulting Markdown text.

    Args:
        file_url: HTTP(S) URL of the PDF to download and process.
        key: Identifier used to namespace extracted images and logs.

    Returns:
        The final Markdown string with redundant images removed.

    Raises:
        Exception: If the PDF cannot be downloaded (non-200 response).

    GPU memory is cleaned up in all cases (success or failure).
    """
    logger.info("Processing file: %s", file_url)

    try:
        # Bounded timeout so a stalled server cannot hang the worker forever.
        response = requests.get(file_url, timeout=60)
        if response.status_code != 200:
            logger.error("Failed to download PDF from %s. Status code: %d", file_url, response.status_code)
            raise Exception(f"Failed to download PDF: {file_url}")

        pdf_bytes = response.content
        logger.info("Downloaded %d bytes for file_url='%s'", len(pdf_bytes), file_url)

        # Analyze PDF with OCR
        dataset = PymuDocDataset(pdf_bytes)
        inference = doc_analyze(
            dataset,
            ocr=True,
            lang=self.language,
            layout_model=self.layout_mode,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable
        )
        logger.info("doc_analyze complete for key='%s'. Started extracting images...", key)

        # Classify images and remove irrelevant ones
        image_writer = ImageWriter(self.s3_writer, self.svm_model)
        pipe_result = inference.pipe_ocr_mode(image_writer, lang=self.language)
        logger.info("OCR pipeline completed for key='%s'.", key)

        md_content = pipe_result.get_markdown(self.prefix + key + "/")
        final_markdown = image_writer.remove_redundant_images(md_content)
        logger.info("Completed PDF process for key='%s'. Final MD length=%d", key, len(final_markdown))
        return final_markdown
    finally:
        # GPU memory is cleaned up after each processing, success or failure.
        self.cleanup_gpu()
|
104 |
|
105 |
class ImageWriter(DataWriter):
|
106 |
"""
|