"""app.py - OCR demo for OpenGVLab/InternVL2_5-1B, served with Gradio."""

import logging

import gradio as gr
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet normalization values
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Build the preprocessing pipeline for a single image."""
    return T.Compose([
        T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
        T.Resize((input_size, input_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

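# Note: the official InternVL preprocessing tiles large images into several
# 448x448 crops (its dynamic_preprocess helper) rather than squashing the
# whole image into one tile as build_transform() does. A minimal sketch of
# that idea, splitting only along the horizontal axis (simplified; not the
# exact upstream algorithm):
def tile_image(image, tile_size=448, max_tiles=4):
    """Split an image into up to max_tiles square tiles (rough sketch)."""
    image = image.convert("RGB")
    cols = min(max_tiles, max(1, round(image.width / tile_size)))
    resized = image.resize((tile_size * cols, tile_size))
    return [
        resized.crop((i * tile_size, 0, (i + 1) * tile_size, tile_size))
        for i in range(cols)
    ]
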
def preprocess_image(image, input_size=448):
    """Preprocess a PIL image into a batched tensor in the model's dtype."""
    transform = build_transform(input_size)
    # Match the model's dtype: bfloat16 on CUDA, float32 on CPU. Note that
    # comparing a torch.device against the string "cpu" is always False, so
    # the check must use device.type.
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    return transform(image).unsqueeze(0).to(dtype).to(device)

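# For example, preprocess_image(Image.new("RGB", (800, 600))) returns a
# tensor of shape (1, 3, 448, 448) on `device`, ready to pass to model.chat()
# as pixel_values.
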
# Load the model and tokenizer
logging.info("Loading model from Hugging Face Hub...")
model_path = "OpenGVLab/InternVL2_5-1B"
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)

# InternVL's chat() replaces the literal "<image>" placeholder in the prompt
# with its internal image tokens, but make sure the tokenizer can also encode
# it as a regular token in case it survives into the input ids.
if "<image>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<image>"])
    model.resize_token_embeddings(len(tokenizer))  # keep embeddings in sync
assert "<image>" in tokenizer.get_vocab(), "`<image>` token is missing from the tokenizer vocabulary."

def describe_image(image):
    """Extract text from the uploaded image."""
    try:
        pixel_values = preprocess_image(image, input_size=448)
        prompt = "<image>\nExtract text from the image, respond with only the extracted text."
        response = model.chat(
            tokenizer=tokenizer,
            pixel_values=pixel_values,
            question=prompt,
            history=None,
            return_history=False,
            # Greedy decoding keeps OCR output deterministic; sampling tends
            # to introduce transcription noise.
            generation_config=dict(max_new_tokens=512, do_sample=False),
        )
        return response
    except Exception as e:
        logging.error(f"Error during processing: {e}")
        return f"Error: {e}"

# Gradio Interface
interface = gr.Interface(
    fn=describe_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Extracted Text", lines=10, interactive=False),
    title="Image to Text",
    description="Upload an image to extract text with OpenGVLab/InternVL2_5-1B.",
)

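# Under concurrent use, Gradio's request queue can be enabled before launch so
# GPU calls are serialized (a sketch; default queue settings assumed):
# interface.queue()
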
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860)