1Noura committed on
Commit
1cc6afc
·
verified ·
1 Parent(s): 94e95af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -42
app.py CHANGED
@@ -1,43 +1,25 @@
1
  import gradio as gr
2
- import wget
3
  from transformers import pipeline
4
  from diffusers import StableDiffusionPipeline
5
  import torch
6
- import time
7
 
8
  # Define the device to use (either "cuda" for GPU or "cpu" for CPU)
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # Function to load the models
12
- def load_models():
13
- global caption_image, sd_pipeline, translator
14
- start_time = time.time()
15
-
16
- # Load the image captioning model
17
- caption_image = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)
18
- print(f"Caption model loaded in {time.time() - start_time:.2f} seconds")
19
-
20
- # Load the Stable Diffusion model with low CPU memory usage
21
- sd_pipeline = StableDiffusionPipeline.from_pretrained(
22
- "runwayml/stable-diffusion-v1-5",
23
- low_cpu_mem_usage=True # Enable low CPU memory usage
24
- ).to(device)
25
- print(f"Stable Diffusion model loaded in {time.time() - start_time:.2f} seconds")
26
-
27
- # Load the translation model
28
- translator = pipeline(
29
- task="translation",
30
- model="facebook/nllb-200-distilled-600M",
31
- device=device
32
- )
33
- print(f"Translator model loaded in {time.time() - start_time:.2f} seconds")
34
-
35
  # Load the models
36
- load_models()
37
-
38
- # Download the images
39
- url1 = "https://github.com/Shahad-b/Image-database/blob/main/sea.jpg?raw=true"
40
- sea = wget.download(url1)
 
 
 
 
 
 
 
41
 
42
  # Function to generate images based on the image's caption
43
  def generate_image_and_translate(image, num_images=1):
@@ -59,23 +41,22 @@ def generate_image_and_translate(image, num_images=1):
59
 
60
  # Set up the Gradio interface
61
  interface = gr.Interface(
62
- fn=generate_image_and_translate,
63
  inputs=[
64
- gr.Image(type="pil", label="Upload Image"),
65
- gr.Slider(minimum=1, maximum=10, label="Number of Images", value=1, step=1)
66
  ],
67
  outputs=[
68
- gr.Gallery(label="Generated Images"),
69
- gr.Textbox(label="Generated Caption (English)", interactive=False),
70
- gr.Textbox(label="Translated Caption (Arabic)", interactive=False)
 
71
  ],
72
- title="Image Generation and Translation",
73
- description="Upload an image to generate new images based on its caption and translate the caption into Arabic.",
74
- examples=[
75
- ["sea.jpg", 3]
76
- ]
77
  )
78
 
 
79
  # Launch the Gradio application within the main guard
80
  if __name__ == "__main__":
81
  interface.launch()
 
1
  import gradio as gr
 
2
  from transformers import pipeline
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
+ import wget
6
 
7
  # Define the device to use (either "cuda" for GPU or "cpu" for CPU)
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  # Load the models
11
+ # Image captioning model to generate captions from uploaded images
12
+ caption_image = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)
13
+ # Stable Diffusion model for generating new images based on captions
14
+ sd_pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
15
+
16
+ # Load the translation model (English to Arabic)
17
+ translator = pipeline(
18
+ task="translation",
19
+ model="facebook/nllb-200-distilled-600M",
20
+ torch_dtype=torch.bfloat16,
21
+ device=device
22
+ )
23
 
24
  # Function to generate images based on the image's caption
25
  def generate_image_and_translate(image, num_images=1):
 
41
 
42
  # Set up the Gradio interface
43
  interface = gr.Interface(
44
+ fn=generate_image_and_translate, # Function to call when processing input
45
  inputs=[
46
+ gr.Image(type="pil", label="Upload Image"), # Input for image upload
47
+ gr.Slider(minimum=1, maximum=10, label="Number of Images", value=1, step=1) # Slider to select number of images
48
  ],
49
  outputs=[
50
+ gr.Gallery(label="Generated Images"), # Output for displaying generated images
51
+ gr.Textbox(label="Generated Caption (English)", interactive=False), # Output for English caption
52
+ gr.Textbox(label="Translated Caption (Arabic)", interactive=False)# Output for Arabic caption
53
+
54
  ],
55
+ title="Image Generation and Translation", # Title of the interface
56
+ description="Upload an image to generate new images based on its caption and translate the caption into Arabic.", # Description
 
 
 
57
  )
58
 
59
+
60
  # Launch the Gradio application within the main guard
61
  if __name__ == "__main__":
62
  interface.launch()