from typing import Any
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
# Constants
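# do_sample=False selects greedy decoding, so captions are deterministic; max_new_tokens caps their length.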
DEFAULT_PARAMS = {
"do_sample": False,
"max_new_tokens": 256,
}
DEFAULT_QUERY = (
"Provide a factual description of this image in up to two paragraphs. "
"Include details on objects, background, scenery, interactions, gestures, poses, and any visible text content. "
"Specify the number of repeated objects. "
"Describe the dominant colors, color contrasts, textures, and materials. "
"Mention the composition, including the arrangement of elements and focus points. "
"Note the camera angle or perspective, and provide any identifiable contextual information. "
"Include details on the style, lighting, and shadows. "
"Avoid subjective interpretations or speculation."
)
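# bfloat16 halves memory use relative to float32; inference falls back to CPU when no GPU is present.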
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
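# CogVLM reuses the Vicuna-7B tokenizer, so the tokenizer and model weights come from
# separate repos; trust_remote_code is required for CogVLM's custom model class.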
tokenizer = LlamaTokenizer.from_pretrained(
pretrained_model_name_or_path="lmsys/vicuna-7b-v1.5",
)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path="THUDM/cogvlm-chat-hf",
torch_dtype=DTYPE,
trust_remote_code=True,
low_cpu_mem_usage=True,
)
model = model.to(device=DEVICE)
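# spaces.GPU requests a GPU for the duration of the call on ZeroGPU Spaces;
# torch.no_grad() disables gradient tracking, saving memory during inference.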
@spaces.GPU
@torch.no_grad()
def generate_caption(
    image: Image.Image,
    params: dict[str, Any] = DEFAULT_PARAMS,
) -> str:
    # Debugging: Check image size and format
    print(f"Uploaded image format: {image.format}, size: {image.size}")
    # Convert image to the expected format (if needed)
    if image.mode != "RGB":
        image = image.convert("RGB")
        print(f"Image converted to RGB mode: {image.mode}")
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=DEFAULT_QUERY,
        history=[],
        images=[image],
    )
    # Debugging: Check tensor shapes
    print(f"Input IDs shape: {inputs['input_ids'].shape}")
    print(f"Images tensor shape: {inputs['images'][0].shape}")
    inputs = {
        "input_ids": inputs["input_ids"].unsqueeze(0).to(device=DEVICE),
        "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to(device=DEVICE),
        "attention_mask": inputs["attention_mask"].unsqueeze(0).to(device=DEVICE),
        "images": [[inputs["images"][0].to(device=DEVICE, dtype=DTYPE)]],
    }
    outputs = model.generate(**inputs, **params)
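    # generate() returns the prompt tokens followed by the new tokens; keep only
    # the newly generated portion.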
    outputs = outputs[:, inputs["input_ids"].shape[1] :]
    result = tokenizer.decode(outputs[0])
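    # Drop the model's stock lead-in phrase and the end-of-sequence marker.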
    result = result.replace("This image showcases", "").strip().removesuffix("</s>").strip().capitalize()
    return result
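# Example of calling the captioner directly (assumes a local file named "example.jpg"):
#     from PIL import Image
#     print(generate_caption(Image.open("example.jpg")))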
# CSS: card-style container with a fixed-height image input and a scrollable caption output
css = """
#container {
    background-color: #f9f9f9;
    padding: 20px;
    border-radius: 15px;
    border: 2px solid #333; /* Darker outline */
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); /* Enhanced shadow */
    max-width: 500px;
    margin: auto;
}
#input_image {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 180px; /* Fixed height */
    object-fit: contain; /* Ensure image fits within the fixed height */
}
#output_caption {
    margin-top: 15px;
    border: 2px solid #333; /* Darker outline */
    border-radius: 8px;
    height: 200px; /* Fixed height */
    overflow-y: auto; /* Scrollable if content exceeds height */
}
#run_button {
    background-color: #fff; /* Light button face */
    color: black; /* Dark text */
    border-radius: 10px;
    padding: 10px;
    cursor: pointer;
    transition: background-color 0.3s ease;
    margin-top: 15px;
}
#run_button:hover {
    background-color: #333; /* Darken on hover */
}
"""
# Gradio interface with vertical alignment and fixed image input height
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        input_image = gr.Image(type="pil", elem_id="input_image")
        run_button = gr.Button(value="Generate Caption", elem_id="run_button")
        output_caption = gr.Textbox(label="Womener AI", show_copy_button=True, elem_id="output_caption", lines=6)
        run_button.click(
            fn=generate_caption,
            inputs=[input_image],
            outputs=output_caption,
        )
demo.launch(share=False)