s2049 committed
Commit 137df29 · verified · 1 Parent(s): 321d937

Create app.py

Files changed (1):
  app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
+ # app.py
+ import gradio as gr
+ from transformers import pipeline, AutoModel, AutoProcessor
+ import torch
+ import os
+ from PIL import Image
+
+ # Initialize models: a multilingual summarizer, an Arabic-to-English
+ # translator, and CLIP for text-image retrieval.
+ summarizer = pipeline("summarization", "csebuetnlp/mT5_multilingual_XLSum")
+ translator_ar2en = pipeline("translation_ar_to_en", "Helsinki-NLP/opus-mt-ar-en")
+ clip_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ # Precompute a CLIP embedding for every image in the gallery directory.
+ def precompute_embeddings(image_dir="images"):
+     image_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)
+                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+
+     embeddings = []
+     for path in image_paths:
+         image = Image.open(path).convert("RGB")
+         inputs = clip_processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             embeddings.append(clip_model.get_image_features(**inputs))
+     # L2-normalize so the dot product below is a cosine similarity.
+     image_embeddings = torch.cat(embeddings)
+     return image_paths, image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
+
+ image_paths, image_embeddings = precompute_embeddings()
+
+ def process(input_text, language):
+     # Text summarization
+     summary = summarizer(input_text, max_length=150, min_length=30)[0]['summary_text']
+
+     # Translate the summary if the input is Arabic; CLIP's text encoder is English-only.
+     if language == "Arabic":
+         translated = translator_ar2en(summary)[0]['translation_text']
+         query_text = translated
+     else:
+         translated = ""
+         query_text = summary
+
+     # Text-image retrieval
+     text_inputs = clip_processor(
+         text=query_text,
+         return_tensors="pt",
+         padding=True,
+         truncation=True
+     )
+     with torch.no_grad():
+         text_emb = clip_model.get_text_features(**text_inputs)
+     text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
+
+     # Drop the batch dimension so topk yields plain indices into image_paths.
+     similarities = (text_emb @ image_embeddings.T).squeeze(0)
+     top_indices = similarities.topk(min(3, len(image_paths))).indices.tolist()
+     results = [image_paths[i] for i in top_indices]
+
+     return summary, translated, results
+
+ # Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🌍 Multi-Task AI: Summarization & Image Retrieval")
+
+     with gr.Row():
+         lang = gr.Dropdown(["English", "Arabic"], label="Input Language")
+         text_input = gr.Textbox(label="Input Text", lines=5)
+
+     with gr.Row():
+         summary_out = gr.Textbox(label="Summary")
+         trans_out = gr.Textbox(label="English Query Text", visible=False)
+
+     gallery = gr.Gallery(label="Retrieved Images", columns=3)
+     submit = gr.Button("Process", variant="primary")
+
+     # Show the English query box only when the input language is Arabic.
+     def toggle_translation(lang):
+         return gr.update(visible=lang == "Arabic")
+
+     lang.change(toggle_translation, lang, trans_out)
+     submit.click(process, [text_input, lang], [summary_out, trans_out, gallery])
+
+ # Guard the launch so the module can be imported without starting the server.
+ if __name__ == "__main__":
+     demo.launch()
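A note on running the app, since the commit adds only this one file: an images/ directory with at least three .png/.jpg files must sit next to app.py, because precompute_embeddings() runs at import time, and the environment needs gradio, transformers, torch, and pillow installed (plus, most likely, sentencepiece for the mT5 and Marian tokenizers). For a quick smoke test of the pipeline without opening the UI, process() can be called directly; a minimal sketch, where the file name smoke_test.py and the sample text are illustrative and not part of the commit:

    # smoke_test.py (hypothetical helper, not part of this commit)
    # Importing app loads the models and embeds the gallery images;
    # the __main__ guard keeps the Gradio server from starting here.
    from app import process

    summary, translated, images = process(
        "Paste a few paragraphs of English text here ...", "English"
    )
    print("Summary:", summary)
    print("Top matches:", images)  # paths of the top-ranked gallery images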