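# Gradio app: generate a detailed, factual caption for an uploaded image with
# THUDM/cogvlm-chat-hf (CogVLM), intended for use as an image-description prompt.
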
from typing import Any
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Constants
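# Greedy decoding (do_sample=False), capped at 256 newly generated tokens.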
DEFAULT_PARAMS = {
    "do_sample": False,
    "max_new_tokens": 256,
}
DEFAULT_QUERY = (
    "Provide a factual description of this image in up to two paragraphs. "
    "Include details on objects, background, scenery, interactions, gestures, poses, and any visible text content. "
    "Specify the number of repeated objects. "
    "Describe the dominant colors, color contrasts, textures, and materials. "
    "Mention the composition, including the arrangement of elements and focus points. "
    "Note the camera angle or perspective, and provide any identifiable contextual information. "
    "Include details on the style, lighting, and shadows. "
    "Avoid subjective interpretations or speculation."
)

DTYPE = torch.bfloat16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and tokenizer
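# CogVLM ships its own modeling code (hence trust_remote_code=True) and uses the
# Vicuna-7B v1.5 tokenizer for its language backbone.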
tokenizer = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path="lmsys/vicuna-7b-v1.5",
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="THUDM/cogvlm-chat-hf",
    torch_dtype=DTYPE,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

model = model.to(device=DEVICE)

# spaces.GPU requests a ZeroGPU device only for the duration of each call;
# torch.no_grad() disables gradient tracking since this is inference-only.
@spaces.GPU
@torch.no_grad()
def generate_caption(
    image: Image.Image,
    params: dict[str, Any] = DEFAULT_PARAMS,
) -> str:
    # Debugging: Check image size and format
    print(f"Uploaded image format: {image.format}, size: {image.size}")
    
    # Ensure the image is in RGB mode; the vision encoder expects three channels
    if image.mode != "RGB":
        image = image.convert("RGB")
        print(f"Image converted to RGB mode: {image.mode}")

    # build_conversation_input_ids is provided by CogVLM's remote code; it tokenizes
    # the query and preprocesses the image into the tensors the model expects.
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=DEFAULT_QUERY,
        history=[],
        images=[image],
    )
    
    # Debugging: Check tensor shapes
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Images tensor shape: {inputs['images'][0].shape}")

    # Add a batch dimension and move tensors to the target device; the model expects
    # `images` as a nested list with one list of image tensors per batch item.
    inputs = {
        "input_ids": inputs["input_ids"].unsqueeze(0).to(device=DEVICE),
        "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(device=DEVICE),
        "attention_mask": inputs["attention_mask"].unsqueeze(0).to(device=DEVICE),
        "images": [[inputs["images"][0].to(device=DEVICE, dtype=DTYPE)]],
    }

    outputs = model.generate(**inputs, **params)
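    # generate() returns the prompt tokens followed by the newly generated ones;
    # slice off the prompt so only the caption itself is decoded.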
    outputs = outputs[:, inputs["input_ids"].shape[1] :]
    result = tokenizer.decode(outputs[0])

    # Strip the boilerplate prefix and the end-of-sequence token, then capitalize
    # only the first character (str.capitalize() would lowercase the rest).
    result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip()
    result = result[:1].upper() + result[1:]
    return result

# CSS for design enhancements: a card-style container with fixed-height image input and caption boxes
css = """
  #container {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 15px;
    border: 2px solid #333; /* Darker outline */
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Enhanced shadow */
    max-width: 450px;
    margin: auto;
  }
  #input_image {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 180px; /* Fixed height */
    object-fit: contain; /* Ensure image fits within the fixed height */
  }
  #output_caption {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 180px; /* Fixed height */
    overflow-y: auto; /* Scrollable if content exceeds height */
  }
  #run_button {
    background-color: #fff; /* Light button background */
    color: black; /* Dark button text */
    border-radius: 10px;
    padding: 10px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 15px;
  }
  #run_button:hover {
    background-color: #333; /* Darker background on hover */
  }
"""

# Gradio interface with vertical alignment and fixed image input height
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        input_image = gr.Image(type="pil", elem_id="input_image")
        run_button = gr.Button(value="Generate Prompt", elem_id="run_button")
        output_caption = gr.Textbox(label="Womener AI", show_copy_button=True, elem_id="output_caption", lines=6)

    run_button.click(
        fn=generate_caption,
        inputs=[input_image],
        outputs=output_caption,
    )

demo.launch(share=False)