import os
import shutil
import asyncio
from io import BufferedIOBase
from pathlib import Path
from typing import List, Optional, Union
from urllib.parse import quote

from dotenv import load_dotenv
from botocore.exceptions import ClientError
from botocore.config import Config
from boto3.session import Session
from pydantic import PrivateAttr
from llama_index.core.async_utils import run_jobs
from llama_index.core.schema import Document
from llama_parse import LlamaParse
from llama_parse.utils import (
    nest_asyncio_err,
    nest_asyncio_msg,
)

load_dotenv()

FileInput = Union[str, bytes, BufferedIOBase]
class S3ImageSaver:
    """Uploads local image files to an S3 bucket and returns their public URLs."""

    def __init__(self, bucket_name, access_key=None, secret_key=None, region_name=None):
        self.bucket_name = bucket_name
        self.region_name = region_name
        self.session = Session(
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key,
            region_name=self.region_name,
        )
        self.s3_client = self.session.client(
            "s3", config=Config(signature_version="s3v4", region_name=self.region_name)
        )

    def save_image(self, image_path, title):
        """Save an image to the S3 bucket and return its URL, or None on failure."""
        try:
            print("---Saving Images---")
            s3_key = f"images/{title}/{os.path.basename(image_path)}"
            with open(image_path, "rb") as file:
                self.s3_client.upload_fileobj(file, self.bucket_name, s3_key)
            # Percent-encode the full key so titles or filenames with special
            # characters still yield a valid URL.
            s3_url = (
                f"https://{self.bucket_name}.s3.{self.region_name}.amazonaws.com/"
                f"{quote(s3_key)}"
            )
            print(f"Image saved to S3 bucket: {s3_url}")
            return s3_url
        except ClientError as e:
            print(f"Error saving image to S3: {e}")
            return None
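
# A minimal usage sketch for S3ImageSaver (illustrative only: the bucket name
# and local file path below are placeholders, and credentials fall back to the
# environment / shared AWS config when the keyword arguments are omitted):
#
#   saver = S3ImageSaver(bucket_name="my-bucket", region_name="us-west-2")
#   url = saver.save_image("tmp/page_1.png", title="My Document")
#   if url:
#       print(url)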
class LlamaParseWithS3(LlamaParse):
    """LlamaParse subclass that uploads extracted images to S3."""

    _s3_image_saver: S3ImageSaver = PrivateAttr()

    def __init__(self, *args, s3_image_saver=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to credentials from the environment; the region defaults
        # to us-west-2 and can be overridden by passing an S3ImageSaver.
        self._s3_image_saver = s3_image_saver or S3ImageSaver(
            bucket_name=os.getenv("S3_BUCKET_NAME"),
            access_key=os.getenv("AWS_ACCESS_KEY_ID"),
            secret_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            region_name="us-west-2",
        )
    async def aget_images(
        self, json_result: List[dict], download_path: str
    ) -> List[dict]:
        """Download images from the parsed result into download_path."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        # Make sure the download directory exists.
        os.makedirs(download_path, exist_ok=True)
        try:
            images = []
            for result in json_result:
                job_id = result["job_id"]
                for page in result["pages"]:
                    if self.verbose:
                        print(f"> Image for page {page['page']}: {page['images']}")
                    for image in page["images"]:
                        image_name = image["name"]
                        image_path = os.path.join(download_path, image_name)
                        # Ensure the file has a valid image extension.
                        if not image_path.endswith((".png", ".jpg")):
                            image_path += ".png"
                        image["path"] = image_path
                        image["job_id"] = job_id
                        image["original_file_path"] = result.get("file_path", None)
                        image["page_number"] = page["page"]
                        image_url = (
                            f"{self.base_url}/api/parsing/job/"
                            f"{job_id}/result/image/{image_name}"
                        )
                        # Fetch the image first so a failed request does not
                        # leave an empty file behind.
                        async with self.client_context() as client:
                            res = await client.get(
                                image_url, headers=headers, timeout=self.max_timeout
                            )
                            res.raise_for_status()
                        with open(image_path, "wb") as f:
                            f.write(res.content)
                        images.append(image)
            return images
        except Exception as e:
            print("Error while downloading images from the parsed result:", e)
            if self.ignore_errors:
                return []
            raise e
    async def aget_images_s3(self, json_result: List[dict], title) -> List[dict]:
        """Download images to a temporary folder, upload them to S3, then clean up."""
        tmp_path = "tmp/"
        images = await self.aget_images(json_result, download_path=tmp_path)
        # Upload each downloaded image to S3 and record its URL.
        for image in images:
            image_path = image["path"]
            try:
                s3_url = self._s3_image_saver.save_image(image_path, title)
                if s3_url:
                    image["image_link"] = s3_url
            except Exception as e:
                print(f"Error saving image to S3: {image_path} - {e}")
        # After processing all images, delete the temporary folder.
        try:
            shutil.rmtree(tmp_path)  # Deletes the folder and all its contents
            print(f"Folder {tmp_path} and all its contents were deleted successfully.")
        except Exception as e:
            print(f"Error deleting folder {tmp_path}: {e}")
        return images
    def get_images(self, json_result: List[dict], title) -> List[dict]:
        """Download images from the parsed result and save them to S3."""
        try:
            return asyncio.run(self.aget_images_s3(json_result, title))
        except RuntimeError as e:
            if nest_asyncio_err in str(e):
                raise RuntimeError(nest_asyncio_msg)
            raise e
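
    # Illustrative example (hypothetical title; assumes json_result came from
    # aget_json / get_json_result below):
    #
    #   images = parser.get_images(json_result, title="Quarterly Report")
    #   links = [img.get("image_link") for img in images]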
    @staticmethod
    def get_single_job_id(json_result):
        """Return the job_id of the first result, or None if the list is empty."""
        if json_result:
            return json_result[0].get("job_id")
        return None
    # Variant of _aget_json for when the job_id is already known:
    # async def _aget_json(
    #     self, job_id, file_path: FileInput, extra_info: Optional[dict] = None
    # ) -> List[dict]:
    #     """Load data from the input path."""
    #     try:
    #         if self.verbose:
    #             print("Started parsing the file under job_id %s" % job_id)
    #         result = await self._get_job_result(job_id, "json")
    #         result["job_id"] = job_id
    #         if not isinstance(file_path, (bytes, BufferedIOBase)):
    #             result["file_path"] = str(file_path)
    #         return [result]
    #     except Exception as e:
    #         file_repr = file_path if isinstance(file_path, str) else "<bytes/buffer>"
    #         print(f"Error while parsing the file '{file_repr}':", e)
    #         if self.ignore_errors:
    #             return []
    #         else:
    #             raise e
    async def aget_json(
        self,
        file_path: Union[List[FileInput], FileInput],
        extra_info: Optional[dict] = None,
    ) -> List[dict]:
        """Load the JSON parse result for one file or a list of files."""
        if isinstance(file_path, (str, Path, bytes, BufferedIOBase)):
            return await self._aget_json(file_path, extra_info=extra_info)
            # The call when the job_id is already known:
            # return await self._aget_json(
            #     job_id="cda0870a-b896-4140-84ea-1565e1aa1565",
            #     file_path=file_path,
            #     extra_info=extra_info,
            # )
        elif isinstance(file_path, list):
            jobs = [self._aget_json(f, extra_info=extra_info) for f in file_path]
            try:
                results = await run_jobs(
                    jobs,
                    workers=self.num_workers,
                    desc="Parsing files",
                    show_progress=self.show_progress,
                )
                # Return flattened results.
                return [item for sublist in results for item in sublist]
            except RuntimeError as e:
                if nest_asyncio_err in str(e):
                    raise RuntimeError(nest_asyncio_msg)
                raise e
        else:
            raise ValueError(
                "file_path must be a string, Path, bytes, BufferedIOBase, "
                "or a list of these."
            )
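
    # Illustrative batch usage (hypothetical file names), from async code:
    #
    #   results = await parser.aget_json(["a.pdf", "b.pdf"])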
    async def _aload_data(
        self,
        job_id,
        extra_info: Optional[dict] = None,
        verbose: bool = False,
    ) -> List[Document]:
        """Load documents for the given job ID."""
        try:
            result = await self._get_job_result(
                job_id, self.result_type.value, verbose=verbose
            )
            docs = [
                Document(
                    text=result[self.result_type.value],
                    metadata=extra_info or {},
                )
            ]
            if self.split_by_page:
                return self._get_sub_docs(docs)
            return docs
        except Exception as e:
            print("Error while parsing the file:", e)
            if self.ignore_errors:
                return []
            raise e
    async def aload_data(
        self,
        job_id,
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """Async wrapper around _aload_data for the given job ID."""
        try:
            return await self._aload_data(
                job_id, extra_info=extra_info, verbose=self.verbose
            )
        except RuntimeError as e:
            if nest_asyncio_err in str(e):
                raise RuntimeError(nest_asyncio_msg)
            raise e

    def load_data(
        self,
        job_id,
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """Synchronously load documents for the given job ID."""
        try:
            return asyncio.run(self.aload_data(job_id, extra_info))
        except RuntimeError as e:
            if nest_asyncio_err in str(e):
                raise RuntimeError(nest_asyncio_msg)
            raise e
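
if __name__ == "__main__":
    # Illustrative end-to-end sketch, not part of the original flow:
    # "report.pdf" and the title are placeholders, and S3_BUCKET_NAME plus AWS
    # credentials are assumed to be set in the environment (e.g. via .env).
    # get_json_result is LlamaParse's synchronous JSON helper.
    parser = LlamaParseWithS3(result_type="markdown", verbose=True)
    json_result = parser.get_json_result("report.pdf")
    images = parser.get_images(json_result, title="report")
    job_id = LlamaParseWithS3.get_single_job_id(json_result)
    if job_id:
        documents = parser.load_data(job_id)
        print(f"Parsed {len(documents)} document(s), uploaded {len(images)} image(s).")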