# imagecaptionerpretrained / src/streamlit_app.py
# Hugging Face Space source by mazen2100 — commit 1472f98 ("Update src/streamlit_app.py").
import html
import io
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import requests
import streamlit as st
import torch
from deep_translator import GoogleTranslator
from PIL import Image, ImageEnhance
from scipy.ndimage import variance
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    CLIPProcessor,
    CLIPModel,
    AutoModelForCausalLM,
    AutoProcessor,
)
# CONFIGURATION
# Page-level Streamlit settings; must run before any other st.* call renders output.
# The title emoji was previously stored as mojibake ("πŸ–ΌοΈ"); it is the framed-picture emoji.
st.set_page_config(
    page_title="🖼️ AI Image Caption Generator",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Define model configurations
# Per-model UI metadata and `model.generate` keyword arguments, keyed by the
# model identifiers used throughout the app ("BLIP", "ViT-GPT2", "GIT", "CLIP").
# CLIP is scored by similarity rather than generation, so it has no "generate_params".
#
# All generation here is deterministic beam search (no `do_sample`), so sampling-only
# knobs such as `top_p` would be ignored (and trigger a transformers warning); they are
# intentionally not set.
MODEL_CONFIGS = {
    "BLIP": {
        "name": "BLIP",
        "icon": "✍️",
        "description": "BLIP excels at generating detailed and accurate image descriptions using vision-language pre-training.",
        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "repetition_penalty": 1.5},
    },
    "ViT-GPT2": {
        "name": "ViT-GPT2",
        "icon": "🔎",
        "description": "ViT-GPT2 combines Vision Transformer with GPT2 for fluent and consistent image captions.",
        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "repetition_penalty": 1.5},
    },
    "GIT": {
        "name": "GIT-base",
        "icon": "📈",
        "description": "GIT generates contextually relevant captions with a focus on scene understanding.",
        "generate_params": {"max_length": 50, "num_beams": 4, "min_length": 8, "repetition_penalty": 1.5},
    },
    "CLIP": {
        "name": "CLIP",
        "icon": "🎨",
        "description": "CLIP provides comprehensive image analysis with confidence scores across content, scene, and style.",
    },
}
# LOADING FUNCTIONS
@st.cache_resource
def load_blip_model():
    """Load the large BLIP captioning checkpoint, cached via ``st.cache_resource``.

    Returns:
        ``(model, processor)``; the model is moved to CUDA when a GPU is available.
    """
    checkpoint = "Salesforce/blip-image-captioning-large"
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    target_device = "cuda" if torch.cuda.is_available() else None
    if target_device is not None:
        blip_model = blip_model.to(target_device)
    return blip_model, blip_processor
@st.cache_resource
def load_vit_gpt2_model():
    """Load the nlpconnect ViT-GPT2 captioner, cached via ``st.cache_resource``.

    Returns:
        ``(model, feature_extractor, tokenizer)``; the model is moved to CUDA when
        a GPU is available.
    """
    repo_id = "nlpconnect/vit-gpt2-image-captioning"
    encoder_decoder = VisionEncoderDecoderModel.from_pretrained(repo_id)
    image_processor = ViTImageProcessor.from_pretrained(repo_id)
    text_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    if torch.cuda.is_available():
        encoder_decoder = encoder_decoder.to("cuda")
    return encoder_decoder, image_processor, text_tokenizer
@st.cache_resource
def load_git_model():
    """Load Microsoft's GIT-base model and processor, cached via ``st.cache_resource``.

    Returns:
        ``(model, processor)``; the model is moved to CUDA when a GPU is available.
    """
    repo_id = "microsoft/git-base"
    git_processor = AutoProcessor.from_pretrained(repo_id)
    git_model = AutoModelForCausalLM.from_pretrained(repo_id)
    return (git_model.to("cuda") if torch.cuda.is_available() else git_model), git_processor
@st.cache_resource
def load_clip_model():
    """Load OpenAI CLIP ViT-L/14 and its processor, cached via ``st.cache_resource``.

    Returns:
        ``(model, processor)``; the model is moved to CUDA when a GPU is available.
    """
    repo_id = "openai/clip-vit-large-patch14"
    clip_processor = CLIPProcessor.from_pretrained(repo_id)
    clip_model = CLIPModel.from_pretrained(repo_id)
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        clip_model = clip_model.to("cuda")
    return clip_model, clip_processor
# IMAGE PROCESSING
def preprocess_image(image, max_size=1024, contrast=1.2, dark_threshold=100, brightness=1.3):
    """Prepare *image* for captioning: downscale, boost contrast, brighten if dark.

    Args:
        image: PIL image to prepare.
        max_size: longest side, in pixels, after downscaling. Images whose longest
            side is already <= max_size are not resized.
        contrast: factor passed to ``ImageEnhance.Contrast``.
        dark_threshold: if the mean grayscale value of the contrast-boosted image is
            below this, the image is brightened.
        brightness: factor passed to ``ImageEnhance.Brightness`` for dark images.

    Returns:
        The processed PIL image.
    """
    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp each side to >= 1 px: with extreme aspect ratios (e.g. 5000x3) the short
        # side would otherwise round down to 0, which Image.resize rejects.
        new_size = (
            max(1, int(image.size[0] * ratio)),
            max(1, int(image.size[1] * ratio)),
        )
        image = image.resize(new_size, Image.LANCZOS)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    # Brightness is judged on the grayscale version, after the contrast boost.
    if np.mean(np.asarray(image.convert('L'))) < dark_threshold:
        image = ImageEnhance.Brightness(image).enhance(brightness)
    return image
def check_image_quality(image, min_side=200, blur_threshold=100.0):
    """Heuristically judge whether *image* is suitable for captioning.

    Args:
        image: PIL image (needs ``width``, ``height`` and ``convert('L')``).
        min_side: minimum accepted width and height, in pixels.
        blur_threshold: minimum variance of the grayscale Laplacian. Lower values
            mean few sharp edges, i.e. a likely blurry image. The value is a common
            rule of thumb and depends on resolution, so it is only used for a warning.

    Returns:
        ``(ok, message)``. *ok* is False when the image is too small or looks blurry.
    """
    if image.width < min_side or image.height < min_side:
        return False, "Image is too small for accurate captioning. Consider using a larger image."
    gray = np.asarray(image.convert('L'), dtype=np.float64)
    # The variance of raw intensities measures global contrast, not focus: a blurred
    # high-contrast photo passes while a sharp low-contrast one fails. The 4-neighbour
    # discrete Laplacian responds to edges, so its variance collapses when edges are
    # smeared. It is computed on interior pixels; float64 avoids uint8 wrap-around.
    laplacian = (
        gray[:-2, 1:-1] + gray[2:, 1:-1] + gray[1:-1, :-2] + gray[1:-1, 2:]
        - 4.0 * gray[1:-1, 1:-1]
    )
    if laplacian.var() < blur_threshold:
        return False, "Image may be too blurry for accurate captioning. Consider using a clearer image."
    return True, "Image quality is sufficient for captioning."
# CAPTION GENERATION FUNCTIONS
def generate_caption(image, model_name, models_data):
    """Route *image* to the captioning function for *model_name*.

    Args:
        image: image passed through to the model-specific function.
        model_name: one of "BLIP", "ViT-GPT2", "GIT", "CLIP".
        models_data: mapping of model name to the tuple returned by that model's
            loader (``(model, processor)``, or ``(model, feature_extractor, tokenizer)``
            for ViT-GPT2).

    Returns:
        The caption string, or "Model not supported" for an unknown *model_name*.
        A known name missing from *models_data* raises ``KeyError``.
    """
    if model_name not in ("BLIP", "ViT-GPT2", "GIT", "CLIP"):
        return "Model not supported"
    bundle = models_data[model_name]
    if model_name == "ViT-GPT2":
        net, extractor, tok = bundle
        return get_vit_gpt2_caption(image, net, extractor, tok)
    net, proc = bundle
    # Resolve the handler only now, so unknown names never touch these functions.
    two_part_handlers = {
        "BLIP": get_blip_caption,
        "GIT": get_git_caption,
        "CLIP": get_clip_caption,
    }
    return two_part_handlers[model_name](image, net, proc)
def get_blip_caption(image, model, processor):
    """Caption *image* with BLIP using the BLIP entry of ``MODEL_CONFIGS``.

    Returns:
        The decoded caption, or "BLIP model error: <reason>" if any step raises.
    """
    try:
        encoded = processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            encoded = {name: value.to("cuda") for name, value in encoded.items()}
        beams = model.generate(**encoded, **MODEL_CONFIGS["BLIP"]["generate_params"])
        text = processor.decode(beams[0], skip_special_tokens=True)
    except Exception as err:
        return f"BLIP model error: {err!s}"
    else:
        return text
def get_vit_gpt2_caption(image, model, feature_extractor, tokenizer):
    """Caption *image* with the ViT-GPT2 encoder-decoder.

    Returns:
        The caption decoded by *tokenizer*, or "ViT-GPT2 model error: <reason>"
        if any step raises.
    """
    try:
        pixel_batch = feature_extractor(images=image, return_tensors="pt")
        use_gpu = torch.cuda.is_available()
        if use_gpu:
            pixel_batch = {key: val.to("cuda") for key, val in pixel_batch.items()}
        options = MODEL_CONFIGS["ViT-GPT2"]["generate_params"]
        sequences = model.generate(**pixel_batch, **options)
        first_sequence = sequences[0]
        return tokenizer.decode(first_sequence, skip_special_tokens=True)
    except Exception as err:
        return f"ViT-GPT2 model error: {err!s}"
def get_git_caption(image, model, processor):
    """Caption *image* with Microsoft GIT using the GIT entry of ``MODEL_CONFIGS``.

    Returns:
        The decoded caption, or "GIT model error: <reason>" if any step raises.
    """
    try:
        batch = processor(images=image, return_tensors="pt")
        if torch.cuda.is_available():
            batch = {key: tensor.to("cuda") for key, tensor in batch.items()}
        generation_kwargs = MODEL_CONFIGS["GIT"]["generate_params"]
        generated = model.generate(**batch, **generation_kwargs)
        caption = processor.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        return f"GIT model error: {exc!s}"
    return caption
# Zero-shot label sets scored by CLIP in get_clip_caption. Order matters: the caption
# code maps top-k indices back into these lists.

# What kind of photograph this is (full phrases, each with its indefinite article).
CONTENT_CATEGORIES = [
    "a portrait photograph",
    "a landscape photograph",
    "a wildlife photograph",
    "an architectural photograph",
    "a street photograph",
    "a food photograph",
    "a fashion photograph",
    "a sports photograph",
    "a macro photograph",
    "a night photograph",
    "an aerial photograph",
    "an underwater photograph",
    "a product photograph",
    "a documentary photograph",
    "a travel photograph",
    "a black and white photograph",
    "an abstract photograph",
    "a concert photograph",
    "a wedding photograph",
    "a nature photograph",
]

# Setting / environment descriptors.
SCENE_ATTRIBUTES = [
    "indoors",
    "outdoors",
    "daytime",
    "nighttime",
    "urban",
    "rural",
    "beach",
    "mountains",
    "forest",
    "desert",
    "snowy",
    "rainy",
    "foggy",
    "sunny",
    "crowded",
    "empty",
    "modern",
    "vintage",
    "colorful",
    "minimalist",
]

# Photographic style / technique descriptors.
STYLE_ATTRIBUTES = [
    "professional",
    "casual",
    "artistic",
    "documentary",
    "aerial view",
    "close-up",
    "wide-angle",
    "telephoto",
    "panoramic",
    "HDR",
    "long exposure",
    "shallow depth of field",
    "silhouette",
    "motion blur",
]
def _clip_top_k(image, labels, model, processor, k):
    """Score *image* against each text in *labels* with CLIP; return ``torch.topk`` of the label softmax."""
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    if torch.cuda.is_available():
        inputs = {name: value.to("cuda") for name, value in inputs.items() if torch.is_tensor(value)}
    # Inference only: no gradients are needed, so don't record an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # softmax over dim=1 turns the per-label logits into probabilities across *labels*.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    return torch.topk(probs, k)


def _strip_indefinite_article(phrase):
    """Return *phrase* without a leading "a " / "an " (e.g. "an aerial photograph" -> "aerial photograph")."""
    for article in ("a ", "an "):
        if phrase.startswith(article):
            return phrase[len(article):]
    return phrase


def _indefinite_article(word):
    """Pick "a" or "an" for *word* with a simple first-letter vowel rule."""
    return "an" if word[:1].lower() in ("a", "e", "i", "o", "u") else "a"


def get_clip_caption(image, model, processor):
    """Build a descriptive caption from CLIP zero-shot scores over content, scene and style labels.

    The top content category and scene are always reported. The runner-up of each is
    added when its probability exceeds 0.15. The top style label and the top content
    confidence are appended.

    Returns:
        The caption string, or "CLIP model error: <reason>" if any step raises.
    """
    try:
        content_probs, content_idx = _clip_top_k(image, CONTENT_CATEGORIES, model, processor, 2)
        scene_probs, scene_idx = _clip_top_k(image, SCENE_ATTRIBUTES, model, processor, 2)
        _, style_idx = _clip_top_k(image, STYLE_ATTRIBUTES, model, processor, 1)

        primary_content = CONTENT_CATEGORIES[content_idx[0].item()]
        primary_scene = SCENE_ATTRIBUTES[scene_idx[0].item()]
        primary_style = STYLE_ATTRIBUTES[style_idx[0].item()]

        secondary_elements = []
        if content_probs[1].item() > 0.15:
            secondary_content = _strip_indefinite_article(CONTENT_CATEGORIES[content_idx[1].item()])
            secondary_elements.append(f"with elements of {secondary_content}")
        if scene_probs[1].item() > 0.15:
            secondary_scene = SCENE_ATTRIBUTES[scene_idx[1].item()]
            secondary_elements.append(f"also showing {secondary_scene} characteristics")

        caption = (
            f"This appears to be {primary_content} captured in "
            f"{_indefinite_article(primary_scene)} {primary_scene} setting"
        )
        if secondary_elements:
            caption += ", " + " ".join(secondary_elements)
        caption += f". The image has {_indefinite_article(primary_style)} {primary_style} quality to it."
        caption += f" (Primary content: {content_probs[0].item()*100:.1f}% confidence)"
        return caption
    except Exception as e:
        return f"CLIP model error: {str(e)}"
# TRANSLATION FUNCTION
def batch_translate(texts, target_lang):
    """Translate every value of *texts* from English into *target_lang*.

    Args:
        texts: mapping of key -> English text (main() passes model name -> caption).
        target_lang: language code handed to deep-translator's ``GoogleTranslator``.

    Returns:
        A dict with the same keys. A value whose translation fails becomes
        "Translation error: <reason>" while the other values are still translated.
        If the translator itself cannot be created (e.g. an unsupported language
        code), every value carries that error.
    """
    if not texts:
        return {}
    try:
        translator = GoogleTranslator(source='en', target=target_lang)
    except Exception as e:
        return {key: f"Translation error: {str(e)}" for key in texts}
    results = {}
    for key, value in texts.items():
        try:
            results[key] = translator.translate(value)
        except Exception as e:
            # One failed request must not discard translations that already succeeded.
            results[key] = f"Translation error: {str(e)}"
    return results
# MAIN APPLICATION
# App-wide dark theme, injected by main(). Raw HTML, so it needs unsafe_allow_html.
_CUSTOM_CSS = """
<style>
body {
    background-color: #0f0f23;
    color: #d1d1e0;
}
.main-header {
    font-size: 2.8rem;
    color: #ff6b6b;
    text-align: center;
    margin-bottom: 1.5rem;
    font-weight: 700;
    text-shadow: 1px 1px 6px rgba(255, 107, 107, 0.4);
}
.sub-header {
    font-size: 1.6rem;
    color: #4ecdc4;
    margin-bottom: 1rem;
    font-weight: 600;
    padding: 0;
    background-color: transparent;
    border: none;
}
.info-text {
    font-size: 1.1rem;
    background-color: #1a1a38;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
    border: 1px solid #2a2a52;
    color: #a3b8ff;
}
.stButton>button {
    width: 100%;
    background-color: #ff6b6b;
    color: #0f0f23;
    border-radius: 6px;
    padding: 10px;
    font-size: 1.1rem;
    font-weight: 500;
    transition: background-color 0.3s, transform 0.2s;
}
.stButton>button:hover {
    background-color: #ff8787;
    transform: translateY(-2px);
}
.caption-card {
    background-color: #1f2a44;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 12px;
    border: 1px solid #2a2a52;
    box-shadow: 0 3px 10px rgba(0,0,0,0.3);
    color: #d1d1e0;
    font-size: 1.2rem;
    transition: transform 0.2s;
}
.caption-card:hover {
    transform: translateY(-2px);
}
.model-badge {
    display: inline-block;
    padding: 4px 10px;
    border-radius: 12px;
    font-size: 0.8rem;
    margin-left: 10px;
    background-color: #4ecdc4;
    color: #0f0f23;
}
.caption-comparison {
    background-color: #1a1a38;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 15px;
    border: 1px solid #2a2a52;
    box-shadow: 0 3px 10px rgba(0,0,0,0.3);
}
.comparison-model-name {
    font-weight: 600;
    color: #ff6b6b;
    margin-bottom: 6px;
    font-size: 1.2rem;
}
.comparison-caption {
    padding: 10px;
    background-color: #1f2a44;
    border-radius: 6px;
    margin-bottom: 10px;
    color: #d1d1e0;
    font-size: 1.2rem;
    border: 1px solid #2a2a52;
}
.tab-content {
    padding: 15px 0;
}
.input-container {
    background-color: transparent;
    padding: 0;
    margin-bottom: 15px;
    border: none;
}
.image-container {
    border-radius: 8px;
    overflow: hidden;
    box-shadow: 0 3px 10px rgba(0,0,0,0.3);
    background-color: #0f0f23;
}
.model-selection-container {
    background-color: #1a1a38;
    padding: 15px;
    border-radius: 8px;
    border: 1px solid #2a2a52;
    box-shadow: 0 3px 10px rgba(0,0,0,0.3);
}
.sidebar-content {
    background-color: #1a1a38;
    padding: 15px;
    border-radius: 8px;
    border: 1px solid #2a2a52;
    margin-bottom: 15px;
}
.sidebar-header {
    font-size: 1.8rem;
    color: #ff6b6b;
    font-weight: 700;
    margin-bottom: 1rem;
}
.sidebar-section {
    margin-bottom: 1.2rem;
}
.stExpander {
    background-color: transparent;
    border: none;
}
.stExpander > div > div {
    background-color: #1a1a38;
    border: 1px solid #2a2a52;
    border-radius: 8px;
    padding: 10px;
}
.stExpander > label {
    color: #4ecdc4;
    font-size: 1.2rem;
    font-weight: 600;
    background-color: transparent;
    border: none;
}
.stRadio > div {
    background-color: transparent;
    border: none;
    padding: 0;
    margin: 0;
    display: flex;
    gap: 15px;
}
.stRadio > div > label {
    color: #d1d1e0;
    font-size: 1rem;
    font-weight: 500;
    background-color: #1f2a44;
    padding: 8px 15px;
    border-radius: 6px;
    transition: background-color 0.3s;
}
.stRadio > div > label:hover {
    background-color: #2a2a52;
}
.stFileUploader > div {
    background-color: transparent;
}
.stTextInput > div {
    background-color: #1f2a44;
    border: 1px solid #2a2a52;
    border-radius: 6px;
}
.stTextInput > div > div > input {
    color: #d1d1e0;
}
</style>
"""

# Display name -> language code passed to GoogleTranslator. deep-translator's Google
# backend lists Chinese as "zh-CN" / "zh-TW" and rejects a bare "zh".
_LANGUAGE_CODES = {
    "Arabic": "ar",
    "French": "fr",
    "Spanish": "es",
    "Chinese": "zh-CN",
    "Russian": "ru",
    "German": "de",
}
_LANGUAGE_FLAGS = {
    "ar": "🇸🇦", "fr": "🇫🇷", "es": "🇪🇸",
    "zh-CN": "🇨🇳", "ru": "🇷🇺", "de": "🇩🇪",
}
# Languages whose translated caption card is rendered right-to-left.
_RTL_LANGUAGES = {"ar"}
_MODEL_CARD_COLORS = {
    "BLIP": "#2a2a52",
    "ViT-GPT2": "#2a3852",
    "GIT": "#2a2a52",
    "CLIP": "#2a3852",
}
# Seconds to wait on a remote host when fetching an image URL (requests has no default).
_URL_TIMEOUT_SECONDS = 15


def _caption_card_html(text, background, text_dir=None, css_class="caption-card"):
    """Return a one-line HTML card for *text*.

    The text is HTML-escaped so that model, translator or exception output cannot
    inject markup into the page.
    """
    dir_attr = f' dir="{text_dir}"' if text_dir else ""
    body = html.escape(str(text))
    return f'<div class="{css_class}" style="background-color: {background};"{dir_attr}>{body}</div>'


def _render_sidebar():
    """Render the static "About" sidebar: description, models, technologies, comparison table."""
    with st.sidebar:
        st.markdown('<div class="sidebar-content">', unsafe_allow_html=True)
        st.markdown('<h2 class="sidebar-header">📘 About This App</h2>', unsafe_allow_html=True)
        st.markdown('<div class="sidebar-section">', unsafe_allow_html=True)
        st.markdown(
            "This NLP project uses cutting-edge AI models to generate and translate "
            "image captions with high accuracy."
        )
        st.markdown('</div>', unsafe_allow_html=True)
        st.markdown('<div class="sidebar-section">', unsafe_allow_html=True)
        st.markdown('<h3 class="sub-header">🛠️ Models Used:</h3>', unsafe_allow_html=True)
        st.markdown(
            "- **BLIP**: Detailed and accurate descriptions\n"
            "- **ViT-GPT2**: Fluent and consistent captions\n"
            "- **GIT**: Contextually relevant descriptions\n"
            "- **CLIP**: Comprehensive image analysis"
        )
        st.markdown('</div>', unsafe_allow_html=True)
        st.markdown('<div class="sidebar-section">', unsafe_allow_html=True)
        st.markdown('<h3 class="sub-header">🔧 Technologies:</h3>', unsafe_allow_html=True)
        st.markdown(
            "- Streamlit\n"
            "- Hugging Face Transformers\n"
            "- PyTorch\n"
            "- deep-translator (Google Translate)"
        )
        st.markdown('</div>', unsafe_allow_html=True)
        with st.expander("📊 Model Comparison"):
            st.markdown(
                "| Model | Strengths | Best For |\n"
                "|---------|----------------------|---------------------|\n"
                "| BLIP | Detailed, accurate | General captioning |\n"
                "| ViT-GPT2| Efficient, consistent| Quick descriptions |\n"
                "| GIT | Contextually relevant| Scene understanding |\n"
                "| CLIP | Classification-based | Image type analysis |"
            )
        st.markdown('</div>', unsafe_allow_html=True)


def _read_image_input():
    """Render the upload / URL widgets and return the chosen image as RGB, or None."""
    image = None
    with st.container():
        st.markdown('<h2 class="sub-header">🌄 Image Input</h2>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="input-container">', unsafe_allow_html=True)
            input_option = st.radio("Choose input method:", ["Upload Image", "Image URL"], horizontal=True)
            if input_option == "Upload Image":
                uploaded_file = st.file_uploader(
                    "Upload an image", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
                )
                if uploaded_file is not None:
                    try:
                        image = Image.open(uploaded_file).convert("RGB")
                    except Exception as e:
                        st.error(f"Error opening image: {e}")
            else:
                url = st.text_input(
                    "Enter Image URL", placeholder="https://example.com/image.jpg", label_visibility="collapsed"
                )
                if url:
                    try:
                        # NOTE(review): this fetches an arbitrary user-supplied URL server-side;
                        # on a shared deployment consider restricting schemes/hosts and size.
                        response = requests.get(url, timeout=_URL_TIMEOUT_SECONDS)
                        if response.status_code == 200 and 'image' in response.headers.get('Content-Type', ''):
                            image = Image.open(io.BytesIO(response.content)).convert("RGB")
                        else:
                            st.error("Invalid image URL or content type.")
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
            st.markdown('</div>', unsafe_allow_html=True)
    return image


def _render_model_selection():
    """Render model checkboxes, translation options and the generate button.

    Returns:
        ``(selected model names in display order, translation language name, button pressed)``.
    """
    st.markdown('<h2 class="sub-header">⚙️ Select Models</h2>', unsafe_allow_html=True)
    with st.container():
        st.markdown('<div class="model-selection-container">', unsafe_allow_html=True)
        # Dict literals evaluate left to right, so the checkboxes render in this order.
        enabled = {
            "BLIP": st.checkbox("BLIP (Bootstrapping Language-Image Pre-training)", value=True),
            "ViT-GPT2": st.checkbox("ViT-GPT2 (Vision Transformer with GPT2)", value=True),
            "GIT": st.checkbox("GIT (Generative Image-to-text Transformer)", value=True),
            "CLIP": st.checkbox("CLIP (Contrastive Language-Image Pre-training)", value=True),
        }
        with st.expander("🔧 Advanced Options"):
            translation_language = st.selectbox(
                "Translation Language",
                list(_LANGUAGE_CODES),
                index=0,
            )
        st.markdown("<br>", unsafe_allow_html=True)
        generate_button = st.button("Generate Captions", type="primary")
        st.markdown('</div>', unsafe_allow_html=True)
    selected = [name for name, is_on in enabled.items() if is_on]
    return selected, translation_language, generate_button


def _render_captions(captions, translations, lang_code, lang_name):
    """Show one tab per model with English + translated captions, then a side-by-side comparison."""
    st.markdown('<h2 class="sub-header">📝 Generated Captions</h2>', unsafe_allow_html=True)
    tabs = st.tabs([f"{MODEL_CONFIGS[name]['icon']} {name}" for name in captions])
    text_dir = "rtl" if lang_code in _RTL_LANGUAGES else "ltr"
    flag = _LANGUAGE_FLAGS.get(lang_code, '🌐')
    for tab, model_name in zip(tabs, captions):
        color = _MODEL_CARD_COLORS[model_name]
        with tab:
            st.markdown('<div class="tab-content">', unsafe_allow_html=True)
            eng_col, trans_col = st.columns(2)
            with eng_col:
                st.markdown("**🇬🇧 English Caption:**")
                st.markdown(_caption_card_html(captions[model_name], color), unsafe_allow_html=True)
            with trans_col:
                st.markdown(f"**{flag} {lang_name} Translation:**")
                st.markdown(
                    _caption_card_html(translations[model_name], color, text_dir),
                    unsafe_allow_html=True,
                )
            st.markdown('</div>', unsafe_allow_html=True)
            with st.expander("ℹ️ About this model"):
                st.markdown(MODEL_CONFIGS[model_name]["description"])
    if len(captions) > 1:
        with st.expander("🔍 Compare All Captions", expanded=True):
            st.markdown('<div class="caption-comparison">', unsafe_allow_html=True)
            for model_name, caption in captions.items():
                name_html = (
                    f'<div class="comparison-model-name">'
                    f'{MODEL_CONFIGS[model_name]["icon"]} {html.escape(model_name)}</div>'
                )
                card_html = _caption_card_html(
                    caption, _MODEL_CARD_COLORS[model_name], css_class="comparison-caption"
                )
                st.markdown(name_html + "\n" + card_html, unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)


def main():
    """Render the caption-generator page.

    The page covers the theme, sidebar, image input, model selection, caption
    generation, translation and results.
    """
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
    st.markdown('<h1 class="main-header">🌌 AI Image Caption Generator</h1>', unsafe_allow_html=True)
    st.markdown(
        '<div class="info-text">Upload an image or provide a URL to generate and translate captions '
        'using advanced AI models. Compare results across multiple models.</div>',
        unsafe_allow_html=True,
    )
    _render_sidebar()

    image = _read_image_input()
    if image is None:
        return

    with st.container():
        col_image, col_models = st.columns([3, 2])
        with col_image:
            with st.spinner("Processing image..."):
                quality_ok, quality_message = check_image_quality(image)
                if not quality_ok:
                    st.warning(quality_message)
                processed_image = preprocess_image(image)
            st.markdown('<div class="image-container">', unsafe_allow_html=True)
            st.image(processed_image, caption="Image for Captioning", use_container_width=True)
            st.markdown('</div>', unsafe_allow_html=True)
        with col_models:
            selected_models, translation_language, generate_button = _render_model_selection()

    if not generate_button:
        return
    if not selected_models:
        st.warning("Please select at least one model.")
        return
    selected_lang_code = _LANGUAGE_CODES[translation_language]

    loaders = {
        "BLIP": load_blip_model,
        "ViT-GPT2": load_vit_gpt2_model,
        "GIT": load_git_model,
        "CLIP": load_clip_model,
    }
    with st.spinner("Loading models..."):
        models_data = {name: loaders[name]() for name in selected_models}

    with st.spinner("Generating captions... This may take a moment"):
        captions = {}
        # Worker threads only run model code; all st.* calls stay on this thread.
        with ThreadPoolExecutor(max_workers=min(len(selected_models), 4)) as executor:
            futures = {
                name: executor.submit(generate_caption, processed_image, name, models_data)
                for name in selected_models
            }
            for name, future in futures.items():
                try:
                    captions[name] = future.result()
                except Exception as e:
                    captions[name] = f"Error generating caption: {str(e)}"

    with st.spinner(f"Translating to {translation_language}..."):
        translations = batch_translate(captions, selected_lang_code)

    _render_captions(captions, translations, selected_lang_code, translation_language)


if __name__ == "__main__":
    main()