Spaces: Running on Zero
File size: 4,357 Bytes
from typing import Any
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
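# Assumed runtime dependencies (inferred from the imports above; not pinned in this
# file): gradio, spaces (the Hugging Face ZeroGPU helper), torch, transformers, and
# Pillow. CogVLM's remote code may need extra packages (e.g. einops); see the
# THUDM/cogvlm-chat-hf model card for the authoritative list.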
# Constants
DEFAULT_PARAMS = {
    "do_sample": False,  # greedy decoding: the caption is deterministic for a given image
    "max_new_tokens": 256,
}
DEFAULT_QUERY = (
    "Provide a factual description of this image in up to two paragraphs. "
    "Include details on objects, background, scenery, interactions, gestures, poses, and any visible text content. "
    "Specify the number of repeated objects. "
    "Describe the dominant colors, color contrasts, textures, and materials. "
    "Mention the composition, including the arrangement of elements and focus points. "
    "Note the camera angle or perspective, and provide any identifiable contextual information. "
    "Include details on the style, lighting, and shadows. "
    "Avoid subjective interpretations or speculation."
)
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
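# Note (editorial assumption, not from the original file): bfloat16 halves memory
# versus float32 and is well supported on recent NVIDIA GPUs. On a CPU-only machine
# the model may still load in bfloat16, but inference will be very slow.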
# Load model and tokenizer (CogVLM-chat pairs with the Vicuna-7B tokenizer)
tokenizer = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path="lmsys/vicuna-7b-v1.5",
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="THUDM/cogvlm-chat-hf",
    torch_dtype=DTYPE,
    trust_remote_code=True,  # CogVLM ships custom modeling code
    low_cpu_mem_usage=True,
)
model = model.to(device=DEVICE)
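# Note: on a ZeroGPU Space, the `spaces.GPU` decorator below attaches a GPU to the
# process only for the duration of each decorated call; roughly speaking, the
# `spaces` package intercepts CUDA usage at startup so the `.to(DEVICE)` above
# works even though no GPU is attached yet. See the ZeroGPU docs for details.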
@spaces.GPU
@torch.no_grad()
def generate_caption(
    image: Image.Image,
    params: dict[str, Any] = DEFAULT_PARAMS,
) -> str:
    # Debugging: check image size and format (format is None for in-memory images)
    print(f"Uploaded image format: {image.format}, size: {image.size}")
    # CogVLM expects RGB input; convert if needed
    if image.mode != "RGB":
        image = image.convert("RGB")
        print(f"Image converted to RGB mode: {image.mode}")
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=DEFAULT_QUERY,
        history=[],
        images=[image],
    )
    # Debugging: check tensor shapes
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Images tensor shape: {inputs['images'][0].shape}")
    # Add a batch dimension and move everything to the target device
    inputs = {
        "input_ids": inputs["input_ids"].unsqueeze(0).to(device=DEVICE),
        "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(device=DEVICE),
        "attention_mask": inputs["attention_mask"].unsqueeze(0).to(device=DEVICE),
        "images": [[inputs["images"][0].to(device=DEVICE, dtype=DTYPE)]],
    }
    outputs = model.generate(**inputs, **params)
    # Strip the prompt tokens, keeping only the newly generated ones
    outputs = outputs[:, inputs["input_ids"].shape[1] :]
    result = tokenizer.decode(outputs[0])
    result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip()
    # Upper-case only the first character; str.capitalize() would also lower-case
    # the rest of the caption and mangle proper nouns.
    result = result[:1].upper() + result[1:]
    return result
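# A minimal local sanity check (hypothetical; assumes a CUDA machine and a test
# image at "test.jpg" — both illustrative, not part of this Space). Left commented
# out so it never interferes with the Gradio launch below:
#
#     if __name__ == "__main__":
#         caption = generate_caption(Image.open("test.jpg"))
#         print(caption)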
# CSS for design enhancements with a fixed image input bar and simplified query
css = """
#container {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 15px;
    border: 2px solid #333; /* Darker outline */
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Enhanced shadow */
    max-width: 450px;
    margin: auto;
}
#input_image {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 180px; /* Fixed height */
    object-fit: contain; /* Ensure image fits within the fixed height */
}
#output_caption {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 180px; /* Fixed height */
    overflow-y: auto; /* Scrollable if content exceeds height */
}
#run_button {
    background-color: #fff; /* White button */
    color: black; /* Black text */
    border-radius: 10px;
    padding: 10px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 15px;
}
#run_button:hover {
    background-color: #333; /* Darken on hover */
}
"""
# Gradio interface with vertical alignment and fixed image input height
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        input_image = gr.Image(type="pil", elem_id="input_image")
        run_button = gr.Button(value="Generate Prompt", elem_id="run_button")
        output_caption = gr.Textbox(label="Womener AI", show_copy_button=True, elem_id="output_caption", lines=6)
    run_button.click(
        fn=generate_caption,
        inputs=[input_image],
        outputs=output_caption,
    )
demo.launch(share=False)
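# Optional (an editorial suggestion, not in the original Space): enable Gradio's
# request queue before launching so concurrent users are serialized while the
# single ZeroGPU worker handles one generation at a time:
#
#     demo.queue().launch(share=False)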