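"""Streamlit app that generates image captions with four pretrained models
(BLIP, ViT-GPT2, GIT, and CLIP), translates the results with deep_translator,
and displays the captions side by side for comparison."""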
import os

# Point Hugging Face and Streamlit caches at /tmp so model downloads work even
# when the home directory is not writable (e.g. on hosted deployments).
os.environ['TRANSFORMERS_CACHE'] = '/tmp/hf'
os.environ['HF_HOME'] = '/tmp/hf'
os.environ['XDG_CACHE_HOME'] = '/tmp'
os.environ['STREAMLIT_HOME'] = '/tmp'
os.makedirs('/tmp/hf', exist_ok=True)

import streamlit as st
import torch
import numpy as np
from PIL import Image, ImageEnhance
import io
import requests
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
    AutoTokenizer,
    CLIPProcessor,
    CLIPModel,
    AutoModelForCausalLM,
    AutoProcessor
)
from deep_translator import GoogleTranslator
from scipy.ndimage import variance
from concurrent.futures import ThreadPoolExecutor

st.set_page_config(
    page_title="🖼️ AI Image Caption Generator",
    layout="wide",
    initial_sidebar_state="expanded"
)

MODEL_CONFIGS = {
    "BLIP": {
        "name": "BLIP",
        "icon": "✅",
        "description": "BLIP (Bootstrapping Language-Image Pre-training) is designed to learn vision-language representations from noisy web data. It excels at generating detailed and accurate image descriptions.",
        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "do_sample": True, "top_p": 0.9, "repetition_penalty": 1.5}
    },
    "ViT-GPT2": {
        "name": "ViT-GPT2",
        "icon": "✅",
        "description": "ViT-GPT2 combines a Vision Transformer for image encoding with GPT-2 for text generation. It is effective at capturing visual details and producing fluent natural-language descriptions.",
        "generate_params": {"max_length": 50, "num_beams": 5, "min_length": 10, "repetition_penalty": 1.5}
    },
    "GIT": {
        "name": "GIT-base",
        "icon": "✅",
        "description": "GIT (Generative Image-to-text Transformer) is designed specifically for image captioning, focusing on generating coherent and contextually relevant descriptions.",
        "generate_params": {"max_length": 50, "num_beams": 4, "min_length": 8, "repetition_penalty": 1.5}
    },
    "CLIP": {
        "name": "CLIP",
        "icon": "✅",
        "description": "CLIP (Contrastive Language-Image Pre-training) analyzes images across multiple dimensions including content type, scene attributes, and photographic style.",
    }
}

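# Note: each "generate_params" dict above is passed straight to model.generate()
# in the get_*_caption helpers below: num_beams enables beam search,
# min_length/max_length bound the caption length, repetition_penalty discourages
# repeated tokens, and BLIP additionally uses nucleus sampling
# (do_sample=True, top_p=0.9).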
@st.cache_resource
def load_blip_model():
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    return model, processor


@st.cache_resource
def load_vit_gpt2_model():
    model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    return model, feature_extractor, tokenizer


@st.cache_resource
def load_git_model():
    processor = AutoProcessor.from_pretrained("microsoft/git-base")
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
    return model, processor


@st.cache_resource
def load_clip_model():
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor

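# Note: st.cache_resource keeps one instance of each model per server process, so
# the weights are downloaded and loaded only on the first request. Everything runs
# on CPU here; if a GPU were available, the loaders could be extended with a move
# such as model.to("cuda"), which this app does not do.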
def preprocess_image(image):
    # Downscale very large images, boost contrast, and brighten dark images
    # before the captioning models see them.
    max_size = 1024
    if max(image.size) > max_size:
        ratio = max_size / max(image.size)
        new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
        image = image.resize(new_size, Image.LANCZOS)
    enhancer = ImageEnhance.Contrast(image)
    image = enhancer.enhance(1.2)
    img_array = np.array(image.convert('L'))
    if np.mean(img_array) < 100:
        brightness_enhancer = ImageEnhance.Brightness(image)
        image = brightness_enhancer.enhance(1.3)
    return image


def check_image_quality(image):
    # Reject tiny images and flag low-variance (flat or blurry) ones.
    if image.width < 200 or image.height < 200:
        return False, "Image is too small. Please use a bigger image."
    img_array = np.array(image.convert('L'))
    if variance(img_array) < 100:
        return False, "Image might be too blurry. Please use a clearer image."
    return True, "Image looks good for captioning."

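# The check above uses the plain grayscale pixel variance as a rough proxy for
# blur or low contrast. A common alternative (not used here) is the variance of
# the Laplacian, e.g. with OpenCV:
#     cv2.Laplacian(np.array(image.convert('L')), cv2.CV_64F).var()
# which responds to missing edges rather than to overall contrast.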
def generate_caption(image, model_name, models_data):
    if model_name == "BLIP":
        model, processor = models_data[model_name]
        return get_blip_caption(image, model, processor)
    elif model_name == "ViT-GPT2":
        model, feature_extractor, tokenizer = models_data[model_name]
        return get_vit_gpt2_caption(image, model, feature_extractor, tokenizer)
    elif model_name == "GIT":
        model, processor = models_data[model_name]
        return get_git_caption(image, model, processor)
    elif model_name == "CLIP":
        model, processor = models_data[model_name]
        return get_clip_caption(image, model, processor)
    return "Model not supported"

def get_blip_caption(image, model, processor):
    try:
        # Image-only input: text kwargs such as padding/truncation are not needed.
        inputs = processor(images=image, return_tensors="pt")
        output = model.generate(**inputs, **MODEL_CONFIGS["BLIP"]["generate_params"])
        caption = processor.decode(output[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        return f"BLIP model error: {str(e)}"


def get_vit_gpt2_caption(image, model, feature_extractor, tokenizer):
    try:
        # The ViT image processor produces pixel_values only (no attention mask).
        inputs = feature_extractor(images=image, return_tensors="pt")
        output = model.generate(
            pixel_values=inputs.pixel_values,
            **MODEL_CONFIGS["ViT-GPT2"]["generate_params"]
        )
        caption = tokenizer.decode(output[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        return f"ViT-GPT2 model error: {str(e)}"


def get_git_caption(image, model, processor):
    try:
        inputs = processor(images=image, return_tensors="pt")
        output = model.generate(**inputs, **MODEL_CONFIGS["GIT"]["generate_params"])
        caption = processor.decode(output[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        return f"GIT model error: {str(e)}"

CONTENT_CATEGORIES = [
    "a portrait photograph", "a landscape photograph", "a wildlife photograph",
    "an architectural photograph", "a street photograph", "a food photograph",
    "a fashion photograph", "a sports photograph", "a macro photograph",
    "a night photograph", "an aerial photograph", "an underwater photograph",
    "a product photograph", "a documentary photograph", "a travel photograph",
    "a black and white photograph", "an abstract photograph", "a concert photograph",
    "a wedding photograph", "a nature photograph"
]

SCENE_ATTRIBUTES = [
    "indoors", "outdoors", "daytime", "nighttime", "urban", "rural",
    "beach", "mountains", "forest", "desert", "snowy", "rainy",
    "foggy", "sunny", "crowded", "empty", "modern", "vintage",
    "colorful", "minimalist"
]

STYLE_ATTRIBUTES = [
    "professional", "casual", "artistic", "documentary", "aerial view",
    "close-up", "wide-angle", "telephoto", "panoramic", "HDR",
    "long exposure", "shallow depth of field", "silhouette", "motion blur"
]

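# These three lists act as zero-shot prompts for CLIP: get_clip_caption below
# scores the image against every phrase in each list (logits_per_image followed
# by a softmax), keeps the top matches, and assembles them into a template
# sentence. Extending the lists changes what CLIP can "say" about an image.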
def get_clip_caption(image, model, processor):
    try:
        # Inference only, so gradients are not needed.
        with torch.no_grad():
            content_inputs = processor(text=CONTENT_CATEGORIES, images=image, return_tensors="pt", padding=True)
            content_outputs = model(**content_inputs)
            content_probs = content_outputs.logits_per_image.softmax(dim=1)[0]
            top_content_probs, top_content_indices = torch.topk(content_probs, 2)

            scene_inputs = processor(text=SCENE_ATTRIBUTES, images=image, return_tensors="pt", padding=True)
            scene_outputs = model(**scene_inputs)
            scene_probs = scene_outputs.logits_per_image.softmax(dim=1)[0]
            top_scene_probs, top_scene_indices = torch.topk(scene_probs, 2)

            style_inputs = processor(text=STYLE_ATTRIBUTES, images=image, return_tensors="pt", padding=True)
            style_outputs = model(**style_inputs)
            style_probs = style_outputs.logits_per_image.softmax(dim=1)[0]
            top_style_probs, top_style_indices = torch.topk(style_probs, 1)

        primary_content = CONTENT_CATEGORIES[top_content_indices[0].item()].replace("a ", "")
        primary_scene = SCENE_ATTRIBUTES[top_scene_indices[0].item()]
        primary_style = STYLE_ATTRIBUTES[top_style_indices[0].item()]

        secondary_elements = []
        if top_content_probs[1].item() > 0.15:
            secondary_content = CONTENT_CATEGORIES[top_content_indices[1].item()].replace("a ", "")
            secondary_elements.append(f"with elements of {secondary_content}")
        if top_scene_probs[1].item() > 0.15:
            secondary_scene = SCENE_ATTRIBUTES[top_scene_indices[1].item()]
            secondary_elements.append(f"also showing {secondary_scene} characteristics")

        detailed_caption = f"This looks like {CONTENT_CATEGORIES[top_content_indices[0].item()]} in a {primary_scene} setting"
        if secondary_elements:
            detailed_caption += ", " + " ".join(secondary_elements)
        detailed_caption += f". The image has a {primary_style} look."
        detailed_caption += f" (Main type: {top_content_probs[0].item()*100:.1f}% sure)"
        return detailed_caption
    except Exception as e:
        return f"CLIP model error: {str(e)}"

def batch_translate(texts, target_lang):
    try:
        translator = GoogleTranslator(source='en', target=target_lang)
        return {key: translator.translate(value) for key, value in texts.items()}
    except Exception as e:
        return {key: f"Translation error: {str(e)}" for key in texts}

def apply_styles():
    st.markdown("""
    <style>
    :root {
        --bg-color: #1a1a2e;
        --sidebar-bg: #16213e;
        --card-bg: #0f3460;
        --accent-color: #4ecdc4;
        --primary-color: #ff6f61;
        --text-light: #f5f5f5;
    }

    body {
        background-color: var(--bg-color);
        color: var(--text-light);
    }

    .big-title {
        font-size: 2.8rem;
        color: var(--primary-color);
        text-align: center;
        margin: 2rem 0;
        font-weight: 700;
    }

    .small-title {
        font-size: 1.8rem;
        color: var(--accent-color);
        margin-bottom: .5rem;
        font-weight: 600;
    }

    .info-box {
        background-color: var(--sidebar-bg);
        padding: .6rem;
        border-radius: 15px;
        border: 2px solid var(--card-bg);
        margin-bottom: 1rem;
        font-size: 1.3rem;
    }

    .stButton>button {
        background-color: var(--primary-color);
        color: #fff !important;
        border-radius: 10px;
        padding: .9rem;
        font-size: 1.2rem;
        font-weight: 700;
        transition: background-color 0.3s, transform 0.2s;
    }
    .stButton>button:hover {
        background-color: #ff9e7d;
        transform: translateY(-2px);
    }

    .caption-card, .compare-caption {
        background-color: var(--card-bg);
        padding: 1rem;
        border-radius: 8px;
        margin: .5rem 0;
        box-shadow: 0 2px 8px rgba(0,0,0,0.3);
        font-size: 1.1rem;
    }

    .caption-card:hover, .tools-item:hover {
        border: 1px solid var(--accent-color);
    }

    .tools-list {
        background-color: var(--sidebar-bg);
        border-radius: 8px;
        padding: .8rem;
        margin-top: .5rem;
        border: 1px solid var(--card-bg);
    }

    .tools-item {
        margin-bottom: .5rem;
        font-size: 1rem;
    }

    .model-icons {
        font-size: 1.3rem;
        color: var(--primary-color);
        margin-right: .5rem;
    }

    .stSidebar {
        background-color: var(--sidebar-bg);
    }
    </style>
    """, unsafe_allow_html=True)

def display_sidebar():
    with st.sidebar:
        st.markdown('<h2 class="small-title">📖 About This App</h2>', unsafe_allow_html=True)
        st.markdown('<div class="info-box">This application is part of our NLP project. It generates captions for images using four pretrained models, each offering a different approach and style for describing the image content.</div>', unsafe_allow_html=True)

        st.markdown('<h3 class="small-title">🤖 AI Models</h3>', unsafe_allow_html=True)
        st.markdown('''
        <div class="tools-list">
            <div class="tools-item"><span class="model-icons">✅</span><b>BLIP</b>: Detailed descriptions</div>
            <div class="tools-item"><span class="model-icons">✅</span><b>ViT-GPT2</b>: Smooth & concise</div>
            <div class="tools-item"><span class="model-icons">✅</span><b>GIT</b>: Scene understanding</div>
            <div class="tools-item"><span class="model-icons">✅</span><b>CLIP</b>: Image categorization</div>
        </div>
        ''', unsafe_allow_html=True)

        st.markdown('<h3 class="small-title">🔧 Tech Used</h3>', unsafe_allow_html=True)
        st.markdown('''
        <div class="tools-list">
            <div class="tools-item">Streamlit</div>
            <div class="tools-item">Hugging Face</div>
            <div class="tools-item">PyTorch</div>
            <div class="tools-item">Google Translator</div>
        </div>
        ''', unsafe_allow_html=True)

        with st.expander("📊 Model Comparison", expanded=False):
            st.markdown('''
            | Model | Strength | Best Use |
            |----------|--------------------|-------------------|
            | BLIP | Detailed captions | Deep analysis |
            | ViT-GPT2 | Smooth captions | Quick summaries |
            | GIT | Scene storytelling | Context insight |
            | CLIP | Categorization | Filtering & tags |
            ''', unsafe_allow_html=True)

def image_input_section():
    with st.container():
        st.markdown('<h2 class="small-title">📁 Image Input</h2>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="input-area">', unsafe_allow_html=True)
            input_option = st.radio("Choose input method:", ["Upload Image", "Image URL"], horizontal=True)
            image = None

            if input_option == "Upload Image":
                uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
                if uploaded_file is not None:
                    try:
                        image = Image.open(uploaded_file).convert("RGB")
                    except Exception as e:
                        st.error(f"Error opening image: {e}")
            else:
                url = st.text_input("Enter Image URL", placeholder="https://example.com/image.jpg", label_visibility="collapsed")
                if url:
                    try:
                        # A timeout keeps the app from hanging on unresponsive URLs.
                        response = requests.get(url, timeout=10)
                        if response.status_code == 200 and 'image' in response.headers.get('Content-Type', ''):
                            image = Image.open(io.BytesIO(response.content)).convert("RGB")
                        else:
                            st.error("Not a valid image URL.")
                    except Exception as e:
                        st.error(f"Error loading image from URL: {e}")
            st.markdown('</div>', unsafe_allow_html=True)

    return image

def model_selection_and_display(image):
    if image:
        with st.container():
            col_models, col_image = st.columns([2, 3])

            with col_models:
                st.markdown('<h2 class="small-title">⚙️ Select Models</h2>', unsafe_allow_html=True)
                with st.container():
                    st.markdown('<div class="model-area">', unsafe_allow_html=True)
                    use_blip = st.checkbox("BLIP (Bootstrapping Language-Image Pre-training)", value=True)
                    use_vit_gpt2 = st.checkbox("ViT-GPT2 (Vision Transformer + GPT-2)", value=True)
                    use_git = st.checkbox("GIT (Generative Image-to-text Transformer)", value=True)
                    use_clip = st.checkbox("CLIP (Contrastive Language-Image Pre-training)", value=True)

                    with st.expander("Translation Language"):
                        translation_language = st.selectbox(
                            "Translation Language",
                            ["Arabic", "French", "Spanish", "Chinese", "Russian", "German"],
                            index=0
                        )
                        language_code_map = {
                            "Arabic": "ar", "French": "fr", "Spanish": "es",
                            "Chinese": "zh", "Russian": "ru", "German": "de"
                        }
                        selected_lang_code = language_code_map[translation_language]

                    st.markdown("<br>", unsafe_allow_html=True)
                    generate_button = st.button("✨ Generate Captions", type="primary")
                    st.markdown('</div>', unsafe_allow_html=True)

            with col_image:
                with st.spinner("Processing image..."):
                    quality_ok, quality_message = check_image_quality(image)
                    if not quality_ok:
                        st.warning(quality_message)
                    processed_image = preprocess_image(image)
                    st.markdown('<div class="image-box">', unsafe_allow_html=True)
                    st.image(processed_image, caption="Processed image", use_container_width=True)
                    st.markdown('</div>', unsafe_allow_html=True)

        return {
            "generate_button": generate_button,
            "use_blip": use_blip,
            "use_vit_gpt2": use_vit_gpt2,
            "use_git": use_git,
            "use_clip": use_clip,
            "selected_lang_code": selected_lang_code,
            "translation_language": translation_language,
            "processed_image": processed_image
        }
    return None

def generate_and_display_captions(config):
    if not config or not config["generate_button"]:
        return

    selected_models = []
    if config["use_blip"]: selected_models.append("BLIP")
    if config["use_vit_gpt2"]: selected_models.append("ViT-GPT2")
    if config["use_git"]: selected_models.append("GIT")
    if config["use_clip"]: selected_models.append("CLIP")

    if not selected_models:
        st.warning("Please choose at least one model.")
        return

    with st.spinner("Loading models..."):
        models_data = {}
        if config["use_blip"]: models_data["BLIP"] = load_blip_model()
        if config["use_vit_gpt2"]: models_data["ViT-GPT2"] = load_vit_gpt2_model()
        if config["use_git"]: models_data["GIT"] = load_git_model()
        if config["use_clip"]: models_data["CLIP"] = load_clip_model()

    with st.spinner("Creating captions... Please wait..."):
        captions = {}
        # Run the selected models in parallel threads; each future returns one caption.
        with ThreadPoolExecutor(max_workers=min(len(selected_models), 4)) as executor:
            future_to_model = {
                executor.submit(generate_caption, config["processed_image"], model_name, models_data): model_name
                for model_name in selected_models
            }
            for future in future_to_model:
                model_name = future_to_model[future]
                try:
                    captions[model_name] = future.result()
                except Exception as e:
                    captions[model_name] = f"Error with caption: {str(e)}"

    with st.spinner(f"Translating to {config['translation_language']}..."):
        translations = batch_translate(captions, config["selected_lang_code"])

    display_captions_in_tabs(captions, translations, config["selected_lang_code"], config["translation_language"])

    if len(captions) > 1:
        display_caption_comparison(captions)

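# Note on the ThreadPoolExecutor above: iterating over future_to_model yields the
# futures in submission order, and future.result() blocks until each one finishes,
# so captions come back in the order the models were selected. Using
# concurrent.futures.as_completed(future_to_model) instead would collect results
# as soon as each model finishes; the displayed output would be the same.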
def display_captions_in_tabs(captions, translations, selected_lang_code, translation_language):
    st.markdown('<h2 class="small-title">📝 Generated Captions</h2>', unsafe_allow_html=True)

    model_colors = {
        "BLIP": "#2d3d7d",
        "ViT-GPT2": "#2d3d7d",
        "GIT": "#2d3d7d",
        "CLIP": "#2d3d7d"
    }

    tabs = st.tabs([f"{MODEL_CONFIGS[model_name]['icon']} {model_name}" for model_name in captions])
    rtl_languages = ["ar"]
    text_dir = "rtl" if selected_lang_code in rtl_languages else "ltr"

    for i, model_name in enumerate(captions):
        with tabs[i]:
            st.markdown('<div class="tab-content">', unsafe_allow_html=True)
            eng_col, trans_col = st.columns(2)
            with eng_col:
                st.markdown("**🇬🇧 English Caption:**")
                st.markdown(f"""
                <div class="caption-card" style="background-color: {model_colors[model_name]};">
                    {captions[model_name]}
                </div>
                """, unsafe_allow_html=True)
            with trans_col:
                lang_flags = {
                    "ar": "🇪🇬", "fr": "🇫🇷", "es": "🇪🇸",
                    "zh": "🇨🇳", "ru": "🇷🇺", "de": "🇩🇪"
                }
                st.markdown(f"**{lang_flags.get(selected_lang_code, '🌐')} {translation_language} Translation:**")
                st.markdown(f"""
                <div class="caption-card" style="background-color: {model_colors[model_name]};" dir="{text_dir}">
                    {translations[model_name]}
                </div>
                """, unsafe_allow_html=True)
            st.markdown('</div>', unsafe_allow_html=True)
            with st.expander("ℹ️ About this model"):
                st.markdown(MODEL_CONFIGS[model_name]["description"])

def display_caption_comparison(captions):
    st.markdown('<h2 class="small-title">⚖ Compare All Captions</h2>', unsafe_allow_html=True)
    st.markdown('<div class="compare-box">', unsafe_allow_html=True)

    model_colors = {
        "BLIP": "#2d3d7d",
        "ViT-GPT2": "#2d3d7d",
        "GIT": "#2d3d7d",
        "CLIP": "#2d3d7d"
    }

    for model_name, caption in captions.items():
        st.markdown(f"""
        <div class="compare-model-name">
            {MODEL_CONFIGS[model_name]['icon']} {model_name}
        </div>
        <div class="compare-caption" style="background-color: {model_colors[model_name]};">
            {caption}
        </div>
        """, unsafe_allow_html=True)
    st.markdown('</div>', unsafe_allow_html=True)

def main():
    apply_styles()

    st.markdown('<h1 class="big-title">🚀 AI Image Caption Generator</h1>', unsafe_allow_html=True)
    st.markdown('<div class="info-box">Generate, translate, and compare image captions easily.</div>', unsafe_allow_html=True)

    display_sidebar()

    image = image_input_section()

    if image:
        config = model_selection_and_display(image)

        if config:
            generate_and_display_captions(config)


if __name__ == "__main__":
    main()