|
""" |
|
Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py |
|
But accepts preloaded model to avoid slowness in use and CUDA forking issues |
|
|
|
Loader that uses Pix2Struct models to image caption |
|
|
|
""" |
|
from typing import List, Union, Any, Tuple |
|
|
|
from langchain.docstore.document import Document |
|
from langchain_community.document_loaders import ImageCaptionLoader |
|
from utils import get_device, clear_torch_cache |
|
from PIL import Image |
|
|
|
|
|
class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that extracts text from images with a Pix2Struct model.

    The processor and model are created on demand by ``load_model()`` (which
    ``load()`` calls when no model is present) and are kept on the instance so
    that repeated ``load()`` calls reuse them.
    """

    def __init__(self, path_images: Union[str, List[str]] = None, model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """Store configuration; no model is loaded here.

        Args:
            path_images: A single image path or a list of image paths.
            model_type: Model identifier passed to ``from_pretrained``.
            max_new_tokens: Generation limit passed to ``model.generate``.
        """
        super().__init__(path_images)
        # load() reads self.image_paths, so set it explicitly instead of
        # relying on the parent class to provide an attribute of that name.
        # NOTE(review): the parent's attribute naming depends on the installed
        # LangChain version -- confirm.
        self.set_image_paths(path_images)
        self._pix2struct_model = None
        self._pix2struct_processor = None
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens

    def set_context(self):
        """Set ``self.device`` to ``'cuda'`` when a GPU is usable, else ``'cpu'``."""
        use_cuda = False
        if get_device() == 'cuda':
            import torch

            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            use_cuda = n_gpus > 0
            if use_cuda:
                self.context_class = torch.device
        self.device = 'cuda' if use_cuda else 'cpu'

    def load_model(self):
        """Create the processor and model, or move an existing model to the device.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If the ``transformers`` package cannot be imported.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as e:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e

        # Always resolve the device first: the reuse branch below needs
        # self.device even when the model was attached without a prior
        # load_model()/set_context() call.
        self.set_context()

        if self._pix2struct_model:
            # Model already exists (e.g. after unload_model() moved it to CPU).
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self

        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the model to CPU (when it supports ``.cpu()``) and call ``clear_torch_cache()``."""
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
            clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Load from a list of image files

        A single string is wrapped into a one-element list; any other value is
        stored as given.
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every configured image and return one Document per image.

        Args:
            prompt: Accepted for call compatibility; not used by this loader.

        Returns:
            Documents whose ``page_content`` is the generated text and whose
            metadata holds ``image_path``. Empty when no paths are configured.
        """
        if self._pix2struct_model is None:
            self.load_model()

        documents = []
        # Tolerate image_paths being None (constructor default) by treating it
        # as "no images" rather than failing on iteration.
        for path_image in self.image_paths or []:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor, model=self._pix2struct_model, path_image=path_image
            )
            documents.append(Document(page_content=caption, metadata=metadata))
        return documents

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """
        Helper function for getting the captions and metadata of an image

        Args:
            processor: Processor used to build model inputs and decode output.
            model: Model whose ``generate`` method produces the token ids.
            path_image: Path of the image file to caption.

        Returns:
            ``(generated_text, {"image_path": path_image})``.

        Raises:
            ValueError: If the image cannot be opened.
        """
        try:
            image = Image.open(path_image)
        except Exception as e:
            raise ValueError(f"Could not get image data for {path_image}") from e

        # Close the underlying file as soon as the tensors have been built.
        with image:
            inputs = processor(images=image, return_tensors="pt")

        inputs = inputs.to(self.device)
        generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata
|
|