ayajoharji committed on
Commit
f148e92
·
verified ·
1 Parent(s): 3951de9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -69
app.py CHANGED
@@ -13,26 +13,37 @@ from PIL import Image, ImageDraw
13
  import requests
14
  from io import BytesIO
15
 
 
 
 
 
 
 
 
 
16
  # Download example images
17
  def download_example_images():
18
  image_urls = [
19
  # URL format: ("Image Description", "Image URL")
20
- ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470"),
21
- ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9"),
22
- ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b"),
23
- ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e"),
24
- ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29"),
25
  ]
26
 
27
  example_images = []
28
  for idx, (description, url) in enumerate(image_urls, start=1):
29
- response = requests.get(url)
30
- if response.status_code == 200:
31
- img = Image.open(BytesIO(response.content))
32
- img.save(f'example{idx}.jpg')
33
- example_images.append([f'example{idx}.jpg'])
34
- else:
35
- print(f"Failed to download image from {url}")
 
 
 
36
  return example_images
37
 
38
  # Download example images and prepare examples list
@@ -44,7 +55,7 @@ def load_image(image):
44
  image_np = np.array(image.convert('RGB'))
45
 
46
  # Resize the image for better processing
47
- resized_image = image.resize((300, 300), resample=Image.LANCZOS)
48
  resized_image_np = np.array(resized_image)
49
 
50
  return resized_image_np
@@ -58,7 +69,7 @@ def extract_colors(image, k=8):
58
  # Ensure data type is float64
59
  pixels = pixels.astype(np.float64)
60
  # Apply K-means clustering to find dominant colors
61
- kmeans = KMeans(n_clusters=k, random_state=0, n_init=10, max_iter=300)
62
  kmeans.fit(pixels)
63
  # Convert normalized colors back to 0-255 scale
64
  colors = (kmeans.cluster_centers_ * 255).astype(int)
@@ -67,15 +78,15 @@ def extract_colors(image, k=8):
67
  # Create an Image for the Color Palette
68
  def create_palette_image(colors):
69
  num_colors = len(colors)
70
- palette_height = 100
71
- palette_width = 100 * num_colors
72
  palette_image = Image.new("RGB", (palette_width, palette_height))
73
 
74
  draw = ImageDraw.Draw(palette_image)
75
  for i, color in enumerate(colors):
76
  # Ensure color values are within the valid range and integers
77
  color = tuple(np.clip(color, 0, 255).astype(int))
78
- draw.rectangle([i * 100, 0, (i + 1) * 100, palette_height], fill=color)
79
 
80
  return palette_image
81
 
@@ -91,68 +102,58 @@ def display_palette(colors):
91
 
92
  # Generate Image Caption Using Hugging Face BLIP
93
  def generate_caption(image):
94
- # Load models only once
95
- if 'processor' not in generate_caption.__dict__:
96
- generate_caption.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
97
- generate_caption.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
98
- processor = generate_caption.processor
99
- model = generate_caption.model
100
-
101
- inputs = processor(images=image, return_tensors="pt")
102
- output = model.generate(**inputs)
103
- caption = processor.decode(output[0], skip_special_tokens=True)
104
  return caption
105
 
106
  # Translate Caption to Arabic Using mBART
107
  def translate_to_arabic(text):
108
- # Load models only once
109
- if 'tokenizer' not in translate_to_arabic.__dict__:
110
- translate_to_arabic.tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
111
- translate_to_arabic.model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
112
- tokenizer = translate_to_arabic.tokenizer
113
- model = translate_to_arabic.model
114
-
115
- tokenizer.src_lang = "en_XX"
116
- encoded = tokenizer(text, return_tensors="pt")
117
- generated_tokens = model.generate(
118
  **encoded,
119
- forced_bos_token_id=tokenizer.lang_code_to_id["ar_AR"]
120
  )
121
- translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
122
  return translated_text
123
 
124
  # Gradio Interface Function (Combining Elements)
125
  def process_image(image):
126
- # Ensure input is a PIL Image
127
- if isinstance(image, np.ndarray):
128
- image = Image.fromarray(image)
129
-
130
- # Convert to RGB format for PIL processing
131
- image_rgb = image.convert("RGB")
132
-
133
- # Load and resize the entire image
134
- resized_image_np = load_image(image_rgb)
135
-
136
- # Convert resized image to PIL Image for Gradio output
137
- resized_image_pil = Image.fromarray(resized_image_np)
138
-
139
- # Generate caption using BLIP model
140
- caption = generate_caption(image_rgb)
141
-
142
- # Translate caption to Arabic
143
- caption_arabic = translate_to_arabic(caption)
144
-
145
- # Extract dominant colors from the entire image
146
- colors = extract_colors(resized_image_np, k=8)
147
- color_palette = display_palette(colors)
148
-
149
- # Create palette image
150
- palette_image = create_palette_image(colors)
151
-
152
- # Combine English and Arabic captions
153
- bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"
154
-
155
- return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
 
 
 
 
156
 
157
  # Create Gradio Interface using Blocks and add a submit button
158
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo:
 
13
  import requests
14
  from io import BytesIO
15
 
16
# Load models globally at startup.
# Loading once at import time avoids re-initializing the weights on every
# request; the handler functions below reference these module-level globals.
print("Loading models...")
# BLIP: image -> English caption.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# mBART-50 many-to-many: used here for English -> Arabic translation.
mbart_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
print("Models loaded successfully.")
23
+
24
  # Download example images
25
def download_example_images():
    """Download a fixed set of Unsplash example images into the working dir.

    Each image is saved as ``example<idx>.jpg``; the return value is a list
    of single-element lists (``[[path], ...]``), the format Gradio's
    ``examples=`` parameter expects.  Any failure is logged and the image is
    skipped, so the app still starts when the network is unavailable.
    """
    image_urls = [
        # (Image Description, Image URL) — description kept for readability.
        ("Sunset over Mountains", "https://images.unsplash.com/photo-1501785888041-af3ef285b470?w=512"),
        ("Forest Path", "https://images.unsplash.com/photo-1502082553048-f009c37129b9?w=512"),
        ("City Skyline", "https://images.unsplash.com/photo-1498598453737-8913e843c47b?w=512"),
        ("Beach and Ocean", "https://images.unsplash.com/photo-1507525428034-b723cf961d3e?w=512"),
        ("Desert Dunes", "https://images.unsplash.com/photo-1501594907352-04cda38ebc29?w=512"),
    ]

    example_images = []
    for idx, (description, url) in enumerate(image_urls, start=1):
        try:
            # Bound the request so a dead CDN cannot hang app startup forever.
            response = requests.get(url, timeout=10)
            if response.status_code == 200:
                # Force RGB so palette/alpha source images can be saved as JPEG.
                img = Image.open(BytesIO(response.content)).convert("RGB")
                path = f'example{idx}.jpg'
                img.save(path)
                example_images.append([path])
            else:
                print(f"Failed to download image from {url}")
        except Exception as e:
            print(f"Exception occurred while downloading image: {e}")
    return example_images
48
 
49
  # Download example images and prepare examples list
 
55
  image_np = np.array(image.convert('RGB'))
56
 
57
  # Resize the image for better processing
58
+ resized_image = image.resize((224, 224), resample=Image.LANCZOS)
59
  resized_image_np = np.array(resized_image)
60
 
61
  return resized_image_np
 
69
  # Ensure data type is float64
70
  pixels = pixels.astype(np.float64)
71
  # Apply K-means clustering to find dominant colors
72
+ kmeans = KMeans(n_clusters=k, random_state=42, n_init=10, max_iter=300)
73
  kmeans.fit(pixels)
74
  # Convert normalized colors back to 0-255 scale
75
  colors = (kmeans.cluster_centers_ * 255).astype(int)
 
78
# Create an Image for the Color Palette
def create_palette_image(colors, swatch_size=50):
    """Render a horizontal strip of square color swatches.

    Parameters
    ----------
    colors : sequence of RGB triples
        0–255 channel values, as produced by ``extract_colors``.
    swatch_size : int, optional
        Width and height in pixels of each swatch.  Defaults to 50, which
        reproduces the original fixed layout exactly.

    Returns
    -------
    PIL.Image.Image
        Image of size ``(swatch_size * len(colors), swatch_size)``.
    """
    num_colors = len(colors)
    palette_image = Image.new("RGB", (swatch_size * num_colors, swatch_size))

    draw = ImageDraw.Draw(palette_image)
    for i, color in enumerate(colors):
        # Clip and cast so k-means centroids outside [0, 255] cannot crash PIL.
        fill = tuple(np.clip(color, 0, 255).astype(int))
        draw.rectangle([i * swatch_size, 0, (i + 1) * swatch_size, swatch_size], fill=fill)

    return palette_image
92
 
 
102
 
103
# Generate Image Caption Using Hugging Face BLIP
def generate_caption(image):
    """Return an English caption for *image* using the global BLIP model."""
    # Preprocess to tensors, generate token ids, then decode to plain text.
    model_inputs = blip_processor(images=image, return_tensors="pt")
    generated_ids = blip_model.generate(**model_inputs)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
109
 
110
# Translate Caption to Arabic Using mBART
def translate_to_arabic(text):
    """Translate English *text* to Arabic with the global mBART-50 model.

    NOTE(review): this mutates the shared tokenizer's ``src_lang`` — harmless
    while English is the only source language used in this app.
    """
    mbart_tokenizer.src_lang = "en_XX"
    model_inputs = mbart_tokenizer(text, return_tensors="pt")
    # Force the decoder to start in Arabic.
    arabic_bos_id = mbart_tokenizer.lang_code_to_id["ar_AR"]
    output_ids = mbart_model.generate(**model_inputs, forced_bos_token_id=arabic_bos_id)
    return mbart_tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
120
 
121
# Gradio Interface Function (Combining Elements)
def process_image(image):
    """End-to-end handler for the Gradio UI.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray
        The uploaded image (Gradio may deliver either form).

    Returns
    -------
    tuple
        ``(bilingual caption str, comma-joined color list str,
        palette PIL image, resized PIL image)``.  On failure, an error
        message plus empty/None placeholders so the UI never crashes.
    """
    try:
        # Gradio may hand us a numpy array; normalize to a PIL Image first.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Force RGB so downstream PIL/numpy code always sees 3 channels.
        image_rgb = image.convert("RGB")

        # Resize for display and for cheaper k-means clustering.
        resized_image_np = load_image(image_rgb)

        # Convert resized image back to PIL for the Gradio output component.
        resized_image_pil = Image.fromarray(resized_image_np)

        # Caption in English, then translate to Arabic.
        caption = generate_caption(image_rgb)
        caption_arabic = translate_to_arabic(caption)

        # Extract dominant colors and build their palette visualization.
        colors = extract_colors(resized_image_np, k=8)
        color_palette = display_palette(colors)
        palette_image = create_palette_image(colors)

        # Combine English and Arabic captions into one display string.
        bilingual_caption = f"English: {caption}\nArabic: {caption_arabic}"

        return bilingual_caption, ", ".join(color_palette), palette_image, resized_image_pil
    except Exception as e:
        # Log the full stack trace server-side (str(e) alone hides the cause)
        # and surface the error detail to the UI instead of a generic message.
        import traceback
        traceback.print_exc()
        return f"An error occurred during processing: {e}", "", None, None
157
 
158
  # Create Gradio Interface using Blocks and add a submit button
159
  with gr.Blocks(css=".gradio-container { height: 1000px !important; }") as demo: