# Author: mazen2100
# Commit: "Update src/app.py" (fe92329, verified)
import os
# Work around permission errors on Hugging Face Spaces by pointing the
# Transformers / Hugging Face / XDG / Streamlit cache locations at /tmp.
# This runs before the streamlit and transformers imports below, presumably
# so those libraries pick up these paths when they are imported.
os.environ.update({
    'TRANSFORMERS_CACHE': '/tmp/hf',
    'HF_HOME': '/tmp/hf',
    'XDG_CACHE_HOME': '/tmp',
    'STREAMLIT_HOME': '/tmp',
})
os.makedirs('/tmp/hf', exist_ok=True)
import streamlit as st
import torch
import numpy as np
from PIL import Image, ImageEnhance
import io
import requests
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
VisionEncoderDecoderModel,
ViTImageProcessor,
AutoTokenizer,
CLIPProcessor,
CLIPModel,
AutoModelForCausalLM,
AutoProcessor
)
from deep_translator import GoogleTranslator
from scipy.ndimage import variance
from concurrent.futures import ThreadPoolExecutor
# ......................... PAGE CONFIGURATION ..........................
# Browser-tab title and wide layout with the sidebar open. This is the first
# st.* call executed by the script. The title emoji is the real U+1F5BC
# character; the previous literal held its UTF-8 bytes decoded with the wrong
# code page, so it displayed as garbage.
st.set_page_config(
    page_title="🖼️ AI Image Caption Generator",
    layout="wide",
    initial_sidebar_state="expanded",
)
# .......................... MODEL CONFIGURATION ....................
# Per-model UI metadata ("name", "icon", "description") plus, for the three
# generative captioners, the keyword arguments forwarded to model.generate().
# CLIP is used as a zero-shot classifier, so it has no "generate_params".
MODEL_CONFIGS = {
    "BLIP": {
        "name": "BLIP",
        "icon": "⭐",
        "description": "BLIP (Bootstrapping Language-Image Pre-training) is designed to learn vision-language representation from noisy web data. It excels at generating detailed and accurate image descriptions.",
        "generate_params": {
            "max_length": 50,
            "num_beams": 5,
            "min_length": 10,
            # Sampling (nucleus, top_p=0.9) on top of beam search: BLIP
            # captions are therefore not deterministic across runs.
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.5,
        },
    },
    "ViT-GPT2": {
        "name": "ViT-GPT2",
        "icon": "⭐",
        "description": "ViT-GPT2 combines Vision Transformer for image encoding with GPT2 for text generation. It's effective at capturing visual details and creating fluent natural language descriptions.",
        "generate_params": {
            "max_length": 50,
            "num_beams": 5,
            "min_length": 10,
            "repetition_penalty": 1.5,
        },
    },
    "GIT": {
        "name": "GIT-base",
        "icon": "⭐",
        "description": "GIT (Generative Image-to-text Transformer) is designed specifically for image captioning tasks, focusing on generating coherent and contextually relevant descriptions.",
        "generate_params": {
            "max_length": 50,
            "num_beams": 4,
            "min_length": 8,
            "repetition_penalty": 1.5,
        },
    },
    "CLIP": {
        "name": "CLIP",
        "icon": "⭐",
        "description": "CLIP (Contrastive Language-Image Pre-training) analyzes images across multiple dimensions including content type, scene attributes, and photographic style.",
    },
}
# ......................... LOADING FUNCTIONS .....................................
@st.cache_resource
def load_blip_model():
    """Load the BLIP base captioning checkpoint together with its processor.

    ``st.cache_resource`` keeps the loaded objects cached, so the weights are
    only loaded once rather than on every Streamlit rerun.

    Returns:
        tuple: ``(model, processor)``.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    blip_model = BlipForConditionalGeneration.from_pretrained(checkpoint)
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    return blip_model, blip_processor
@st.cache_resource
def load_vit_gpt2_model():
    """Load the nlpconnect ViT-GPT2 captioner, its image processor and tokenizer.

    Cached with ``st.cache_resource`` so the checkpoint is loaded only once.

    Returns:
        tuple: ``(model, feature_extractor, tokenizer)``.
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    encoder_decoder = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    image_processor = ViTImageProcessor.from_pretrained(checkpoint)
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return encoder_decoder, image_processor, text_tokenizer
@st.cache_resource
def load_git_model():
    """Load Microsoft's GIT-base model and its processor.

    Cached with ``st.cache_resource`` so the checkpoint is loaded only once.

    Returns:
        tuple: ``(model, processor)``.
    """
    checkpoint = "microsoft/git-base"
    git_processor = AutoProcessor.from_pretrained(checkpoint)
    git_model = AutoModelForCausalLM.from_pretrained(checkpoint)
    return git_model, git_processor
@st.cache_resource
def load_clip_model():
    """Load the ViT-B/32 CLIP model (the smaller OpenAI checkpoint) and its processor.

    Cached with ``st.cache_resource`` so the checkpoint is loaded only once.

    Returns:
        tuple: ``(model, processor)``.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)
    clip_model = CLIPModel.from_pretrained(checkpoint)
    return clip_model, clip_processor
# ......................... IMAGE PROCESSING ...............................
def preprocess_image(image, max_size=1024):
    """Downscale, contrast-boost and (if dark) brighten an image for captioning.

    Args:
        image: PIL image. The callers in this app pass an RGB image.
        max_size: Longest allowed side in pixels. Larger images are shrunk
            with LANCZOS resampling, keeping the aspect ratio.

    Returns:
        A new PIL image. Resize/enhance produce copies, so the input is left
        untouched.
    """
    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp each side to at least 1 px. Otherwise a very elongated image
        # (e.g. 3000x2) truncates its short side to 0 and Image.resize raises.
        new_size = tuple(max(1, int(side * ratio)) for side in image.size)
        image = image.resize(new_size, Image.LANCZOS)
    image = ImageEnhance.Contrast(image).enhance(1.2)
    # Brighten by 30% when the mean grayscale intensity (0-255) is below 100.
    if np.mean(np.array(image.convert('L'))) < 100:
        image = ImageEnhance.Brightness(image).enhance(1.3)
    return image
def check_image_quality(image, min_side=200, blur_threshold=100.0):
    """Heuristically decide whether an image is usable for captioning.

    Args:
        image: PIL-like image exposing ``width``, ``height`` and ``convert('L')``.
        min_side: Minimum accepted width and height in pixels.
        blur_threshold: Minimum variance of the Laplacian of the grayscale
            image. The Laplacian responds to edges, so a low variance means
            few sharp edges, i.e. a blurry or featureless image.

    Returns:
        ``(ok, message)``. ``message`` is meant to be shown to the user.
    """
    # Local import keeps this block self-contained; scipy is already a
    # dependency of the app.
    from scipy.ndimage import laplace

    if image.width < min_side or image.height < min_side:
        return False, "Image is too small. Please use a bigger image."
    # laplace keeps the input dtype, so cast first: in uint8 the negative
    # responses would wrap around.
    gray = np.asarray(image.convert('L'), dtype=np.float64)
    # Plain intensity variance measures global contrast, not sharpness. The
    # variance of the Laplacian is the standard blur heuristic.
    if laplace(gray).var() < blur_threshold:
        return False, "Image might be too blurry. Please use a clearer image."
    return True, "Image looks good for captioning."
# .................... CAPTION GENERATION FUNCTIONS ..................
def generate_caption(image, model_name, models_data):
    """Route *image* to the captioning function that matches *model_name*.

    Args:
        image: Preprocessed image, passed through unchanged.
        model_name: One of "BLIP", "ViT-GPT2", "GIT" or "CLIP".
        models_data: Mapping from model name to the tuple returned by that
            model's ``load_*_model`` loader.

    Returns:
        The caption string, or "Model not supported" for any other name.
    """
    if model_name not in ("BLIP", "ViT-GPT2", "GIT", "CLIP"):
        return "Model not supported"
    loaded = models_data[model_name]
    if model_name == "ViT-GPT2":
        # The only loader that returns three components.
        vit_model, extractor, tok = loaded
        return get_vit_gpt2_caption(image, vit_model, extractor, tok)
    model, processor = loaded
    if model_name == "BLIP":
        return get_blip_caption(image, model, processor)
    if model_name == "GIT":
        return get_git_caption(image, model, processor)
    return get_clip_caption(image, model, processor)
def get_blip_caption(image, model, processor):
    """Caption *image* with BLIP using the "BLIP" generation settings.

    Returns:
        The decoded caption, or "BLIP model error: <details>" if any step
        (preprocessing, generation or decoding) raises.
    """
    try:
        encoded = processor(images=image, return_tensors="pt", padding=True, truncation=True)
        token_ids = model.generate(**encoded, **MODEL_CONFIGS["BLIP"]["generate_params"])
        text = processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception as err:
        return f"BLIP model error: {str(err)}"
    return text
def get_vit_gpt2_caption(image, model, feature_extractor, tokenizer):
    """Caption *image* with the ViT encoder / GPT-2 decoder model.

    Returns:
        The decoded caption, or "ViT-GPT2 model error: <details>" if any step
        raises.
    """
    try:
        batch = feature_extractor(images=image, return_tensors="pt", padding=True)
        pixel_values = batch.pixel_values
        decode_kwargs = MODEL_CONFIGS["ViT-GPT2"]["generate_params"]
        # The image processor may not produce a mask; pass None in that case.
        mask = batch.attention_mask if hasattr(batch, "attention_mask") else None
        generated = model.generate(pixel_values=pixel_values, **decode_kwargs, attention_mask=mask)
        text = tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as err:
        return f"ViT-GPT2 model error: {str(err)}"
    return text
def get_git_caption(image, model, processor):
    """Caption *image* with Microsoft GIT using the "GIT" generation settings.

    Returns:
        The decoded caption, or "GIT model error: <details>" if any step raises.
    """
    try:
        model_inputs = processor(images=image, return_tensors="pt", padding=True)
        sequences = model.generate(**model_inputs, **MODEL_CONFIGS["GIT"]["generate_params"])
        text = processor.decode(sequences[0], skip_special_tokens=True)
    except Exception as exc:
        return f"GIT model error: {str(exc)}"
    return text
# .................... CLIP CATEGORIES ................
# Text prompts for CLIP zero-shot classification. get_clip_caption scores the
# image against each list separately.
CONTENT_CATEGORIES = [
    "a portrait photograph", "a landscape photograph", "a wildlife photograph",
    "an architectural photograph", "a street photograph", "a food photograph",
    "a fashion photograph", "a sports photograph", "a macro photograph",
    "a night photograph", "an aerial photograph", "an underwater photograph",
    "a product photograph", "a documentary photograph", "a travel photograph",
    "a black and white photograph", "an abstract photograph", "a concert photograph",
    "a wedding photograph", "a nature photograph"
]
SCENE_ATTRIBUTES = [
    "indoors", "outdoors", "daytime", "nighttime", "urban", "rural",
    "beach", "mountains", "forest", "desert", "snowy", "rainy",
    "foggy", "sunny", "crowded", "empty", "modern", "vintage",
    "colorful", "minimalist"
]
STYLE_ATTRIBUTES = [
    "professional", "casual", "artistic", "documentary", "aerial view",
    "close-up", "wide-angle", "telephoto", "panoramic", "HDR",
    "long exposure", "shallow depth of field", "silhouette", "motion blur"
]


def _strip_leading_article(phrase):
    """Return *phrase* without a leading "a " or "an " article."""
    for article in ("a ", "an "):
        if phrase.startswith(article):
            return phrase[len(article):]
    return phrase


def _clip_rank(model, processor, image, labels, k):
    """Score *image* against *labels* with CLIP; return top-*k* ``(probs, indices)``."""
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Inference only: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    return torch.topk(probs, k)


def get_clip_caption(image, model, processor):
    """Describe *image* by zero-shot classifying it with CLIP.

    The image is ranked against CONTENT_CATEGORIES (top 2), SCENE_ATTRIBUTES
    (top 2) and STYLE_ATTRIBUTES (top 1). The runner-up content or scene label
    is mentioned only when its softmax probability exceeds 0.15.

    Returns:
        A short description ending with the top content type's confidence,
        or "CLIP model error: <details>" if anything raises.
    """
    try:
        content_probs, content_idx = _clip_rank(model, processor, image, CONTENT_CATEGORIES, 2)
        scene_probs, scene_idx = _clip_rank(model, processor, image, SCENE_ATTRIBUTES, 2)
        _, style_idx = _clip_rank(model, processor, image, STYLE_ATTRIBUTES, 1)

        primary_content = CONTENT_CATEGORIES[content_idx[0].item()]
        primary_scene = SCENE_ATTRIBUTES[scene_idx[0].item()]
        primary_style = STYLE_ATTRIBUTES[style_idx[0].item()]

        secondary_elements = []
        if content_probs[1].item() > 0.15:
            # Drop only the leading article. str.replace("a ", "") would also
            # cut "a " inside words and leaves "an " untouched.
            secondary_content = _strip_leading_article(CONTENT_CATEGORIES[content_idx[1].item()])
            secondary_elements.append(f"with elements of {secondary_content}")
        if scene_probs[1].item() > 0.15:
            secondary_scene = SCENE_ATTRIBUTES[scene_idx[1].item()]
            secondary_elements.append(f"also showing {secondary_scene} characteristics")

        detailed_caption = f"This looks like {primary_content} in a {primary_scene} setting"
        if secondary_elements:
            detailed_caption += ", " + " ".join(secondary_elements)
        detailed_caption += f". The image has a {primary_style} look."
        detailed_caption += f" (Main type: {content_probs[0].item()*100:.1f}% sure)"
        return detailed_caption
    except Exception as e:
        return f"CLIP model error: {str(e)}"
# ......................... TRANSLATION FUNCTION .......................
def batch_translate(texts, target_lang, translator=None):
    """Translate every value of *texts* from English into *target_lang*.

    Args:
        texts: Mapping of key (in this app, the model name) to English text.
        target_lang: Target language code passed to deep_translator's
            ``GoogleTranslator``.
        translator: Optional object with a ``translate(text)`` method, used
            instead of creating a new ``GoogleTranslator``. This lets callers
            reuse one instance or plug in another backend.

    Returns:
        A dict with the same keys as *texts*. Each value is the translation or
        "Translation error: <details>". A failure on one entry does not affect
        the others.
    """
    if translator is None:
        try:
            translator = GoogleTranslator(source='en', target=target_lang)
        except Exception as e:
            # Nothing can be translated without a translator (e.g. the
            # language code was rejected), so every entry reports the error.
            return {key: f"Translation error: {str(e)}" for key in texts}

    results = {}
    for key, value in texts.items():
        try:
            results[key] = translator.translate(value)
        except Exception as e:
            # Keep a per-entry failure (e.g. a network or service error) from
            # discarding translations that already succeeded.
            results[key] = f"Translation error: {str(e)}"
    return results
# .......................... GUI STYLE .............................
def apply_styles():
    """Inject the app-wide CSS as a raw ``<style>`` element.

    The CSS covers the dark colour palette, title classes, info boxes,
    buttons, caption cards, tool lists and sidebar background.
    """
    stylesheet = """
<style>
:root {
--bg-color: #1a1a2e;
--sidebar-bg: #16213e;
--card-bg: #0f3460;
--accent-color: #4ecdc4;
--primary-color: #ff6f61;
--text-light: #f5f5f5;
}
body {
background-color: var(--bg-color);
color: var(--text-light);
}
.big-title {
font-size: 2.8rem;
color: var(--primary-color);
text-align: center;
margin: 2rem 0;
font-weight: 700;
}
.small-title {
font-size: 1.8rem;
color: var(--accent-color);
margin-bottom: .5rem;
font-weight: 600;
}
.info-box {
background-color: var(--sidebar-bg);
padding: .6rem;
border-radius: 15px;
border: 2px solid var(--card-bg);
margin-bottom: 1rem;
font-size: 1.3rem;
}
.stButton>button {
background-color: var(--primary-color);
color: #fff !important;
border-radius: 10px;
padding: .9rem;
font-size: 1.2rem;
font-weight: 700;
transition: background-color 0.3s, transform 0.2s;
}
.stButton>button:hover {
background-color: #ff9e7d;
transform: translateY(-2px);
}
.caption-card, .compare-caption {
background-color: var(--card-bg);
padding: 1rem;
border-radius: 8px;
margin: .5rem 0;
box-shadow: 0 2px 8px rgba(0,0,0,0.3);
font-size: 1.1rem;
}
.caption-card:hover, .tools-item:hover {
border: 1px solid var(--accent-color);
}
.tools-list {
background-color: var(--sidebar-bg);
border-radius: 8px;
padding: .8rem;
margin-top: .5rem;
border: 1px solid var(--card-bg);
}
.tools-item {
margin-bottom: .5rem;
font-size: 1rem;
}
.model-icons {
font-size: 1.3rem;
color: var(--primary-color);
margin-right: .5rem;
}
.stSidebar {
background-color: var(--sidebar-bg);
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
# ............................ COMPONENTS .......................
def display_sidebar():
    """Render the sidebar: app blurb, model list, tech stack and a model comparison table.

    The heading emoji are the real characters (📘 🤖 🔧 📊). Previously their
    UTF-8 bytes had been decoded with the wrong code page and displayed as
    garbage.
    """
    with st.sidebar:
        st.markdown('<h2 class="small-title">📘 About This App</h2>', unsafe_allow_html=True)
        st.markdown('<div class="info-box">This application is part of our NLP project, focused on generating captions for images using four pretrained models to generate captions, each offering a different approach and style to describe the image content.</div>', unsafe_allow_html=True)
        st.markdown('<h3 class="small-title">🤖 AI Models</h3>', unsafe_allow_html=True)
        st.markdown('''
<div class="tools-list">
<div class="tools-item"><span class="model-icons">⭐</span><b>BLIP</b>: Detailed descriptions</div>
<div class="tools-item"><span class="model-icons">⭐</span><b>ViT-GPT2</b>: Smooth & concise</div>
<div class="tools-item"><span class="model-icons">⭐</span><b>GIT</b>: Scene understanding</div>
<div class="tools-item"><span class="model-icons">⭐</span><b>CLIP</b>: Image categorization</div>
</div>
''', unsafe_allow_html=True)
        st.markdown('<h3 class="small-title">🔧 Tech Used</h3>', unsafe_allow_html=True)
        st.markdown('''
<div class="tools-list">
<div class="tools-item">Streamlit</div>
<div class="tools-item">Hugging Face</div>
<div class="tools-item">PyTorch</div>
<div class="tools-item">Google Translator</div>
</div>
''', unsafe_allow_html=True)
        with st.expander("📊 Model Comparison", expanded=False):
            st.markdown('''
| Model | Strength | Best Use |
|----------|--------------------|-------------------|
| BLIP | Detailed captions | Deep analysis |
| ViT-GPT2 | Smooth captions | Quick summaries |
| GIT | Scene storytelling | Context insight |
| CLIP | Categorization | Filtering & tags |
''', unsafe_allow_html=True)
def image_input_section():
    """Let the user provide an image by upload or by URL.

    Returns:
        The image converted to RGB (``PIL.Image``), or ``None`` when nothing
        valid was provided. Failures are reported to the user with
        ``st.error``.
    """
    # Seconds to wait for the remote host before giving up on a URL.
    url_timeout = 15
    with st.container():
        st.markdown('<h2 class="small-title">🌄 Image Input</h2>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="input-area">', unsafe_allow_html=True)
            input_option = st.radio("Choose input method:", ["Upload Image", "Image URL"], horizontal=True)
            image = None
            if input_option == "Upload Image":
                uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
                if uploaded_file is not None:
                    try:
                        image = Image.open(uploaded_file).convert("RGB")
                    except Exception as e:
                        st.error(f"Error opening image: {e}")
            else:
                url = st.text_input("Enter Image URL", placeholder="https://example.com/image.jpg", label_visibility="collapsed")
                if url:
                    try:
                        # Without a timeout, an unresponsive host would block
                        # this script run indefinitely.
                        # NOTE(review): the URL is user-supplied and fetched
                        # server-side; restrict hosts/schemes if this app ever
                        # runs outside a sandboxed environment.
                        response = requests.get(url, timeout=url_timeout)
                        if response.status_code == 200 and 'image' in response.headers.get('Content-Type', ''):
                            image = Image.open(io.BytesIO(response.content)).convert("RGB")
                        else:
                            st.error("Not a valid image URL.")
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
            st.markdown('</div>', unsafe_allow_html=True)
    return image
def model_selection_and_display(image):
    """Render model/language selection next to a preview of the processed image.

    Args:
        image: RGB image from ``image_input_section`` (or ``None``).

    Returns:
        ``None`` when there is no image. Otherwise a dict with the keys
        "generate_button", "use_blip", "use_vit_gpt2", "use_git", "use_clip",
        "selected_lang_code", "translation_language" and "processed_image".
    """
    if not image:
        return None
    with st.container():
        col_models, col_image = st.columns([2, 3])
        with col_models:
            st.markdown('<h2 class="small-title">⚙️ Select Models</h2>', unsafe_allow_html=True)
            with st.container():
                st.markdown('<div class="model-area">', unsafe_allow_html=True)
                use_blip = st.checkbox("BLIP (Bootstrapping Language-Image Pre-training)", value=True)
                use_vit_gpt2 = st.checkbox("ViT-GPT2 (ViT-GPT2 combines Vision Transformer)", value=True)
                use_git = st.checkbox("GIT (Generative Image-to-text Transformer)", value=True)
                use_clip = st.checkbox("CLIP (Contrastive Language-Image Pre-training)", value=True)
                with st.expander("TRANSLATION LANGUAGES"):
                    translation_language = st.selectbox(
                        "Translation Language",
                        ["Arabic", "French", "Spanish", "Chinese", "Russian", "German"],
                        index=0,
                    )
                # deep_translator's GoogleTranslator only knows Chinese as
                # "zh-CN" / "zh-TW" and rejects a bare "zh", which made every
                # Chinese translation fail.
                language_code_map = {
                    "Arabic": "ar", "French": "fr", "Spanish": "es",
                    "Chinese": "zh-CN", "Russian": "ru", "German": "de",
                }
                selected_lang_code = language_code_map[translation_language]
                st.markdown("<br>", unsafe_allow_html=True)
                generate_button = st.button("✨ Generate Captions", type="primary")
                st.markdown('</div>', unsafe_allow_html=True)
        with col_image:
            with st.spinner("Processing image..."):
                quality_ok, quality_message = check_image_quality(image)
                if not quality_ok:
                    # Advisory only: captioning still proceeds.
                    st.warning(quality_message)
                processed_image = preprocess_image(image)
            st.markdown('<div class="image-box">', unsafe_allow_html=True)
            st.image(processed_image, caption="PROCESSED IMAGE ", use_container_width=True)
            st.markdown('</div>', unsafe_allow_html=True)
    return {
        "generate_button": generate_button,
        "use_blip": use_blip,
        "use_vit_gpt2": use_vit_gpt2,
        "use_git": use_git,
        "use_clip": use_clip,
        "selected_lang_code": selected_lang_code,
        "translation_language": translation_language,
        "processed_image": processed_image,
    }
def generate_and_display_captions(config):
    """Caption the processed image with every selected model, translate, and render.

    Does nothing unless *config* (from ``model_selection_and_display``) is
    present and the "Generate Captions" button was clicked this run.
    Captioning runs in a thread pool with up to 4 workers, one task per model.
    """
    if not config or not config["generate_button"]:
        return

    # (display name, config flag) in checkbox order.
    model_flags = [
        ("BLIP", "use_blip"),
        ("ViT-GPT2", "use_vit_gpt2"),
        ("GIT", "use_git"),
        ("CLIP", "use_clip"),
    ]
    selected_models = [name for name, flag in model_flags if config[flag]]
    if not selected_models:
        st.warning("Please CHOOSE at least one model.")
        return

    loaders = {
        "BLIP": load_blip_model,
        "ViT-GPT2": load_vit_gpt2_model,
        "GIT": load_git_model,
        "CLIP": load_clip_model,
    }
    with st.spinner("Loading models..."):
        models_data = {name: loaders[name]() for name in selected_models}

    with st.spinner("Creating captions... Please wait... "):
        captions = {}
        with ThreadPoolExecutor(max_workers=min(len(selected_models), 4)) as pool:
            pending = [
                (name, pool.submit(generate_caption, config["processed_image"], name, models_data))
                for name in selected_models
            ]
            # Collect in submission order so tabs follow the checkbox order.
            for name, future in pending:
                try:
                    captions[name] = future.result()
                except Exception as e:
                    captions[name] = f"Error with caption: {str(e)}"

    with st.spinner(f"Translating to {config['translation_language']}..."):
        translations = batch_translate(captions, config["selected_lang_code"])

    display_captions_in_tabs(captions, translations, config["selected_lang_code"], config["translation_language"])
    # Side-by-side comparison only makes sense with two or more captions.
    if len(captions) > 1:
        display_caption_comparison(captions)
def display_captions_in_tabs(captions, translations, selected_lang_code, translation_language):
    """Show one tab per model with the English caption beside its translation.

    Args:
        captions: Mapping model name -> English caption. Insertion order is
            the tab order.
        translations: Mapping model name -> translated caption (same keys).
        selected_lang_code: Code used for translation. "ar" is rendered
            right-to-left.
        translation_language: Human-readable language name for the heading.
    """
    import html  # local: escapes text embedded in the raw-HTML caption cards

    st.markdown('<h2 class="small-title">📝 Generated Captions</h2>', unsafe_allow_html=True)
    # Currently the same background for every model.
    model_colors = {
        "BLIP": "#2d3d7d",
        "ViT-GPT2": "#2d3d7d",
        "GIT": "#2d3d7d",
        "CLIP": "#2d3d7d",
    }
    # Both "zh" and "zh-CN" map to the Chinese flag; GoogleTranslator needs "zh-CN".
    lang_flags = {
        "ar": "🇪🇬", "fr": "🇫🇷", "es": "🇪🇸",
        "zh": "🇨🇳", "zh-CN": "🇨🇳", "ru": "🇷🇺", "de": "🇩🇪",
    }
    tabs = st.tabs([f"{MODEL_CONFIGS[model_name]['icon']} {model_name}" for model_name in captions])
    rtl_languages = ["ar"]
    text_dir = "rtl" if selected_lang_code in rtl_languages else "ltr"
    for tab, model_name in zip(tabs, captions):
        # Captions, error messages and translations are plain text. Escape
        # them so characters like "<" or "&" show up literally instead of
        # being parsed as HTML.
        english_text = html.escape(str(captions[model_name]))
        translated_text = html.escape(str(translations[model_name]))
        with tab:
            st.markdown('<div class="tab-content">', unsafe_allow_html=True)
            eng_col, trans_col = st.columns(2)
            with eng_col:
                st.markdown("**🇬🇧 English Caption:**")
                st.markdown(f"""
<div class="caption-card" style="background-color: {model_colors[model_name]};">
{english_text}
</div>
""", unsafe_allow_html=True)
            with trans_col:
                st.markdown(f"**{lang_flags.get(selected_lang_code, '🌐')} {translation_language} Translation:**")
                st.markdown(f"""
<div class="caption-card" style="background-color: {model_colors[model_name]};" dir="{text_dir}">
{translated_text}
</div>
""", unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)
            with st.expander("ℹ️ About this model"):
                st.markdown(MODEL_CONFIGS[model_name]["description"])
def display_caption_comparison(captions):
    """List every model's caption one after another for easy comparison.

    Args:
        captions: Mapping model name -> caption text, shown in insertion order.
    """
    import html  # local: escapes text embedded in the raw-HTML cards

    st.markdown('<h2 class="small-title">⭐ Compare All Captions</h2>', unsafe_allow_html=True)
    st.markdown('<div class="compare-box">', unsafe_allow_html=True)
    # Currently the same background for every model.
    model_colors = {
        "BLIP": "#2d3d7d",
        "ViT-GPT2": "#2d3d7d",
        "GIT": "#2d3d7d",
        "CLIP": "#2d3d7d",
    }
    for model_name, caption in captions.items():
        # Captions and error messages may contain "<" or "&". Escape them so
        # they display literally instead of being swallowed as HTML markup.
        safe_caption = html.escape(str(caption))
        st.markdown(f"""
<div class="compare-model-name">
{MODEL_CONFIGS[model_name]['icon']} {html.escape(model_name)}
</div>
<div class="compare-caption" style="background-color: {model_colors[model_name]};">
{safe_caption}
</div>
""", unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)
# ................................... MAIN APPLICATION ............................
def main():
    """Streamlit entry point.

    Applies styles, draws the header and sidebar, collects an image, then
    hands off to model selection and caption generation.
    """
    apply_styles()
    st.markdown('<h1 class="big-title">🌌 AI Image Caption Generator</h1>', unsafe_allow_html=True)
    st.markdown('<div class="info-box">Generate, translate, and compare image captions easily.</div>', unsafe_allow_html=True)
    display_sidebar()

    image = image_input_section()
    # Nothing more to show until the user has supplied an image.
    if not image:
        return
    config = model_selection_and_display(image)
    if config:
        generate_and_display_captions(config)


if __name__ == "__main__":
    main()