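"""Gradio demo for OpenGVLab/InternVL2_5-78B.

Two tabs: an image tab that extracts text (OCR) from an upload, and a
text-only chatbot tab. Image preprocessing follows InternVL's dynamic
tiling scheme (448x448 tiles plus an optional thumbnail).
"""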
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import gradio as gr

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Build the image transform
def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform
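
# Example: build_transform(448) maps any PIL image to a [3, 448, 448] float
# tensor, resized bicubically and normalized with ImageNet statistics.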

# Dynamic preprocessing: split the image into up to `max_num` tiles of
# image_size x image_size, on the grid whose aspect ratio best matches the
# input (following the reference preprocessing from the InternVL model card)
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    # All (columns, rows) grids whose tile count lies in [min_num, max_num]
    target_ratios = sorted(
        set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1)
            for j in range(1, n + 1) if min_num <= i * j <= max_num),
        key=lambda x: x[0] * x[1]
    )
    # Pick the grid closest in aspect ratio to the image; on ties, prefer the
    # larger grid when the image has enough pixels to fill it
    best_diff, target_aspect_ratio = float('inf'), (1, 1)
    area = orig_width * orig_height
    for ratio in target_ratios:
        diff = abs(aspect_ratio - ratio[0] / ratio[1])
        if diff < best_diff:
            best_diff, target_aspect_ratio = diff, ratio
        elif diff == best_diff and area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
            target_aspect_ratio = ratio
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # Resize to the grid's exact pixel size, then slice out each tile
    resized_img = image.resize((target_width, target_height))
    cols = target_width // image_size
    processed_images = [
        resized_img.crop((
            (i % cols) * image_size,
            (i // cols) * image_size,
            ((i % cols) + 1) * image_size,
            ((i // cols) + 1) * image_size
        ))
        for i in range(blocks)
    ]
    # Optionally append a full-image thumbnail for global context
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

# Convert a user-uploaded PIL image into a batch of normalized tile tensors
def load_image(image: Image.Image, input_size=448, max_num=12):
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values
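
# For the defaults above, the result has shape
# [num_tiles (+1 thumbnail when the image is tiled), 3, 448, 448].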

# Load the model and tokenizer
path = 'OpenGVLab/InternVL2_5-78B'
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True
).eval().cuda()
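
# NOTE: in bfloat16 the 78B checkpoint is roughly 160 GB of weights alone, so
# `.cuda()` onto a single device will generally not fit it. A sketch of one
# alternative (assumes `accelerate` is installed): shard across all visible
# GPUs with a device map and drop the explicit `.cuda()`:
#
#   model = AutoModel.from_pretrained(
#       path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
#       use_flash_attn=True, trust_remote_code=True, device_map='auto',
#   ).eval()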

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Image-tab handler: tile the upload and ask the model to transcribe its text
def process_image(image):
    try:
        pixel_values = load_image(image, max_num=12).to(torch.bfloat16).cuda()
        generation_config = dict(max_new_tokens=1024, do_sample=True)
        question = '<image>\nExtract text from the image, respond with only the extracted text.'
        response = model.chat(tokenizer, pixel_values, question, generation_config)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Chat-tab handler: runs the model in text-only mode (pixel_values=None)
def chatbot(input_text, history=None):
    history = history or []  # avoid sharing a mutable default across calls
    try:
        generation_config = dict(max_new_tokens=1024, do_sample=True)
        response, updated_history = model.chat(
            tokenizer, None, input_text, generation_config,
            history=history, return_history=True
        )
        return response, updated_history
    except Exception as e:
        return f"Error: {str(e)}", history
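
# Note: with return_history=True, InternVL's chat() returns the conversation
# as a list of (question, answer) tuples; passing that list back in on the
# next call is what preserves multi-turn context.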

# Create Gradio Tabs
with gr.Blocks() as demo:
    with gr.Tab("Image Processing"):
        gr.Markdown("Upload an image; the InternVL model will extract and return its text.")
        image_input = gr.Image(type="pil")
        image_output = gr.Textbox(label="Response")
        image_btn = gr.Button("Process")
        image_btn.click(process_image, inputs=image_input, outputs=image_output)

    with gr.Tab("Chatbot"):
        gr.Markdown("Chat with the model.")
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_output = gr.Textbox(label="Response")
        chatbot_history = gr.State([])
        chatbot_btn = gr.Button("Send")
        chatbot_btn.click(chatbot, inputs=[chatbot_input, chatbot_history], outputs=[chatbot_output, chatbot_history])
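
# For shared deployments, consider chaining demo.queue() before launch() so
# GPU-bound requests are processed sequentially rather than concurrently.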

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)