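"""Gradio demo for OpenGVLab/InternVL2_5-78B.

Two tabs share one model instance: an image tab that extracts text from an
uploaded picture, and a plain-text chatbot tab that keeps conversation history.
"""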
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import gradio as gr
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
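# Standard ImageNet normalization statistics, matching InternVL's pretrained vision encoder.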
# Build the image transform
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
# Dynamic preprocessing: tile the image into a grid of image_size crops whose
# shape best matches the original aspect ratio (InternVL-style tiling).
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    # All (cols, rows) grids whose tile count falls in [min_num, max_num],
    # ordered by tile count.
    target_ratios = sorted(
        set((i, j) for n in range(min_num, max_num + 1)
            for i in range(1, n + 1) for j in range(1, n + 1)
            if min_num <= i * j <= max_num),
        key=lambda x: x[0] * x[1]
    )
    # Pick the grid whose aspect ratio is closest to the image's; on ties,
    # prefer the larger grid when the image has enough pixels to fill it.
    target_aspect_ratio = (1, 1)
    best_diff = float('inf')
    for ratio in target_ratios:
        diff = abs(aspect_ratio - ratio[0] / ratio[1])
        if diff < best_diff:
            best_diff, target_aspect_ratio = diff, ratio
        elif diff == best_diff and orig_width * orig_height > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
            target_aspect_ratio = ratio
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = [
        resized_img.crop((
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size
        ))
        for i in range(blocks)
    ]
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
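# Example: a 1280x960 (4:3) upload with the defaults selects the (4, 3) grid,
# yielding twelve 448x448 tiles plus one thumbnail when use_thumbnail=True.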
# Load an image from the user upload and tile it into a batch of pixel tensors
def load_image(image, input_size=448, max_num=12):
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(tile) for tile in images]
    # Shape: [num_tiles (+1 thumbnail), 3, input_size, input_size]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
# Load the model and tokenizer
path = 'OpenGVLab/InternVL2_5-78B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True
).eval().cuda()
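# Note: the 78B checkpoint needs well over 100 GB of GPU memory in bfloat16,
# and .cuda() places the whole model on one device. On smaller hardware,
# consider device_map="auto" in from_pretrained to shard across GPUs, or a
# smaller variant such as OpenGVLab/InternVL2_5-8B.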
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
# Define the function for the Gradio image interface
def process_image(image):
    # Guard against an empty upload from the Gradio component.
    if image is None:
        return "Please upload an image first."
    try:
        pixel_values = load_image(image, max_num=12).to(torch.bfloat16).cuda()
        generation_config = dict(max_new_tokens=1024, do_sample=True)
        question = '<image>\nExtract text from the image, respond with only the extracted text.'
        response = model.chat(tokenizer, pixel_values, question, generation_config)
        return response
    except Exception as e:
        return f"Error: {str(e)}"
# Define the function for the text-based chatbot interface
def chatbot(input_text, history=None):
    try:
        generation_config = dict(max_new_tokens=1024, do_sample=True)
        # pixel_values=None makes this a text-only turn; return_history=True
        # returns the updated conversation for the gr.State component.
        response, updated_history = model.chat(tokenizer, None, input_text, generation_config,
                                               history=history, return_history=True)
        return response, updated_history
    except Exception as e:
        return f"Error: {str(e)}", history
# Create the Gradio UI with tabs
with gr.Blocks() as demo:
    with gr.Tab("Image Processing"):
        gr.Markdown("Upload an image and get detailed responses using the InternVL model.")
        image_input = gr.Image(type="pil")
        image_output = gr.Textbox(label="Response")
        image_btn = gr.Button("Process")
        image_btn.click(process_image, inputs=image_input, outputs=image_output)
    with gr.Tab("Chatbot"):
        gr.Markdown("Chat with the model.")
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_output = gr.Textbox(label="Response")
        chatbot_history = gr.State([])
        chatbot_btn = gr.Button("Send")
        chatbot_btn.click(chatbot, inputs=[chatbot_input, chatbot_history], outputs=[chatbot_output, chatbot_history])
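# Note: with multiple simultaneous users, calling demo.queue() before launch
# would queue requests for the single shared model instance.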
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)