Create app.py
app.py (ADDED)
@@ -0,0 +1,78 @@
# app.py
import gradio as gr
from transformers import pipeline, AutoModel, AutoProcessor
import torch
import os
from PIL import Image

# Initialize models: multilingual summarizer, Arabic->English translator, and CLIP for retrieval
summarizer = pipeline("summarization", "csebuetnlp/mT5_multilingual_XLSum")
translator_ar2en = pipeline("translation_ar_to_en", "Helsinki-NLP/opus-mt-ar-en")
clip_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Precompute CLIP embeddings for every image in the gallery directory
def precompute_embeddings(image_dir="images"):
    image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    embeddings = []
    for path in image_paths:
        image = Image.open(path).convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            embeddings.append(clip_model.get_image_features(**inputs))
    image_embeddings = torch.cat(embeddings)
    # Normalize so the dot product below is a cosine similarity
    return image_paths, image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

image_paths, image_embeddings = precompute_embeddings()

def process(input_text, language):
    # Text summarization
    summary = summarizer(input_text, max_length=150, min_length=30)[0]['summary_text']

    # Translate the summary to English if the input is Arabic (the query for CLIP is English)
    translated = ""
    if language == "Arabic":
        translated = translator_ar2en(summary)[0]['translation_text']
        query_text = translated
    else:
        query_text = summary

    # Text-image retrieval
    text_inputs = clip_processor(
        text=query_text,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    with torch.no_grad():
        text_emb = clip_model.get_text_features(**text_inputs)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

    # Rank gallery images by cosine similarity and keep the best matches
    similarities = (text_emb @ image_embeddings.T).squeeze(0)
    top_k = min(3, len(image_paths))
    top_indices = similarities.topk(top_k).indices.tolist()
    results = [image_paths[i] for i in top_indices]

    return summary, translated, results

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌍 Multi-Task AI: Summarization & Image Retrieval")

    with gr.Row():
        lang = gr.Dropdown(["English", "Arabic"], label="Input Language")
        text_input = gr.Textbox(label="Input Text", lines=5)

    with gr.Row():
        summary_out = gr.Textbox(label="Summary")
        trans_out = gr.Textbox(label="English Query Text", visible=False)

    gallery = gr.Gallery(label="Retrieved Images", columns=3)
    submit = gr.Button("Process", variant="primary")

    # Show the translated-query box only when the input language is Arabic
    def toggle_translation(lang):
        return gr.update(visible=lang == "Arabic")

    lang.change(toggle_translation, lang, trans_out)
    submit.click(process, [text_input, lang], [summary_out, trans_out, gallery])

if __name__ == "__main__":
    demo.launch()
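
For a quick local smoke test outside the Gradio UI, a script along the following lines could sit next to app.py. This is a minimal sketch, not part of the commit: it assumes an images/ directory with a few .jpg/.png files exists alongside app.py and that the three models can be downloaded; the file name test_local.py and the sample_text value are purely illustrative.

# test_local.py -- illustrative smoke test, not included in this commit
from app import process  # importing app loads the models and precomputes image embeddings

# Hypothetical input text; any English paragraph works the same way
sample_text = (
    "The new solar farm outside the city began operating this week, "
    "supplying clean electricity to roughly 40,000 homes and cutting "
    "the region's reliance on imported fuel."
)

summary, translated, images = process(sample_text, "English")
print("Summary:", summary)
print("Translated query:", translated or "(input was already English)")
print("Top matching images:", images)

Since the commit adds only app.py, the Space presumably also needs a requirements.txt listing gradio, transformers, torch, sentencepiece, and Pillow; that file is not part of this change.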