1Noura commited on
Commit
95953e3
·
verified ·
1 Parent(s): a602863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -3,28 +3,40 @@ import wget
3
  from transformers import pipeline
4
  from diffusers import StableDiffusionPipeline
5
  import torch
 
6
 
7
  # Define the device to use (either "cuda" for GPU or "cpu" for CPU)
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # Load the models
11
- caption_image = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)
12
- sd_pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
 
13
 
14
- # Load the translation model (English to Arabic)
15
- translator = pipeline(
16
- task="translation",
17
- model="facebook/nllb-200-distilled-600M",
18
- torch_dtype=torch.bfloat16,
19
- device=device
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Download the images
23
  url1 = "https://github.com/Shahad-b/Image-database/blob/main/sea.jpg?raw=true"
24
  sea = wget.download(url1)
25
 
26
-
27
-
28
  # Function to generate images based on the image's caption
29
  def generate_image_and_translate(image, num_images=1):
30
  # Generate caption in English from the uploaded image
 
3
  from transformers import pipeline
4
  from diffusers import StableDiffusionPipeline
5
  import torch
6
+ import time
7
 
8
  # Define the device to use (either "cuda" for GPU or "cpu" for CPU)
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # Function to load the models
12
+ def load_models():
13
+ global caption_image, sd_pipeline, translator
14
+ start_time = time.time()
15
 
16
+ # Load the image captioning model
17
+ caption_image = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)
18
+ print(f"Caption model loaded in {time.time() - start_time:.2f} seconds")
19
+
20
+ # Load the Stable Diffusion model
21
+ sd_pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)
22
+ print(f"Stable Diffusion model loaded in {time.time() - start_time:.2f} seconds")
23
+
24
+ # Load the translation model (English to Arabic)
25
+ translator = pipeline(
26
+ task="translation",
27
+ model="facebook/nllb-200-distilled-600M",
28
+ torch_dtype=torch.bfloat16,
29
+ device=device
30
+ )
31
+ print(f"Translator model loaded in {time.time() - start_time:.2f} seconds")
32
+
33
+ # Load the models
34
+ load_models()
35
 
36
  # Download the images
37
  url1 = "https://github.com/Shahad-b/Image-database/blob/main/sea.jpg?raw=true"
38
  sea = wget.download(url1)
39
 
 
 
40
  # Function to generate images based on the image's caption
41
  def generate_image_and_translate(image, num_images=1):
42
  # Generate caption in English from the uploaded image