tournas commited on
Commit
fca69f9
·
verified ·
1 Parent(s): 7b64947

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -20
app.py CHANGED
@@ -10,9 +10,8 @@ from ultralytics import YOLO
10
  from gtts import gTTS
11
  from PIL import Image
12
  from nltk.tokenize import sent_tokenize
13
- import spaces
14
 
15
- # Set device (use GPU if available)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # Load environment variables
@@ -23,22 +22,24 @@ if not api_key:
23
  # Initialize OpenAI client
24
  client = OpenAI(api_key=api_key)
25
 
26
- # Load YOLO model
27
- yolo_model = YOLO("yolov8s.pt")
28
-
29
- # Load Stable Diffusion pipeline
30
- stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32)
31
- stable_diffusion.to(device)
32
-
33
  # Download NLTK data
34
  nltk.download("punkt")
35
 
36
- # Load summarization pipeline
37
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if device == "cuda" else -1)
 
 
 
 
 
 
 
 
 
 
38
 
39
- @spaces.GPU
40
  # Function to detect objects in an image
41
- def detect_objects(image_path):
42
  results = yolo_model(image_path)
43
  detected_objects = []
44
  for r in results:
@@ -52,20 +53,20 @@ def detect_objects(image_path):
52
  def generate_story(detected_objects):
53
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
54
  response = client.chat.completions.create(
55
- model="gpt-4o-mini", # Use GPT-4 or GPT-3.5-turbo
56
  messages=[{"role": "user", "content": story_prompt}],
57
  max_tokens=200
58
  )
59
  return response.choices[0].message.content.strip()
60
 
61
  # Function to summarize the story into scenes
62
- def summarize_story(story):
63
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
64
  scenes = sent_tokenize(summary)
65
  return scenes
66
 
67
  # Function to generate images for each scene
68
- def generate_images(story):
69
  scenes = summarize_story(story)
70
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
71
  images = []
@@ -88,8 +89,13 @@ def full_pipeline(image):
88
  image_path = f"temp_{uuid.uuid4().hex}.jpg"
89
  image.save(image_path)
90
 
 
 
 
 
 
91
  # Detect objects in the image
92
- detected_objects = detect_objects(image_path)
93
  if not detected_objects:
94
  return "No objects detected. Please upload a different image.", "", [], None
95
 
@@ -99,12 +105,12 @@ def full_pipeline(image):
99
  return "Failed to generate a story. Please try again.", "", [], None
100
 
101
  # Summarize the story into scenes
102
- scenes = summarize_story(story)
103
  if not scenes:
104
  return story, "No scenes extracted.", [], None
105
 
106
  # Generate images for each scene
107
- images = generate_images(story)
108
  if not images:
109
  return story, "\n".join(scenes), [], None
110
 
@@ -119,7 +125,7 @@ def full_pipeline(image):
119
  except Exception as e:
120
  return f"An error occurred: {str(e)}", "", [], None
121
 
122
- # Gradio UI
123
  demo = gr.Interface(
124
  fn=full_pipeline,
125
  inputs=gr.Image(type="pil"),
@@ -133,6 +139,9 @@ demo = gr.Interface(
133
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
134
  )
135
 
 
 
 
136
  # Launch the app
137
  if __name__ == "__main__":
138
  demo.launch()
 
10
  from gtts import gTTS
11
  from PIL import Image
12
  from nltk.tokenize import sent_tokenize
 
13
 
14
+ # Set device (use GPU if available, but don't initialize CUDA here)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  # Load environment variables
 
22
  # Initialize OpenAI client
23
  client = OpenAI(api_key=api_key)
24
 
 
 
 
 
 
 
 
25
  # Download NLTK data
26
  nltk.download("punkt")
27
 
28
+ # Lazy-load models to avoid initializing CUDA in the main process
29
+ def load_yolo_model():
30
+ return YOLO("yolov8s.pt")
31
+
32
+ def load_stable_diffusion():
33
+ return StableDiffusionPipeline.from_pretrained(
34
+ "runwayml/stable-diffusion-v1-5",
35
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
36
+ ).to(device)
37
+
38
+ def load_summarizer():
39
+ return pipeline("summarization", model="facebook/bart-large-cnn", device=0 if device == "cuda" else -1)
40
 
 
41
  # Function to detect objects in an image
42
+ def detect_objects(image_path, yolo_model):
43
  results = yolo_model(image_path)
44
  detected_objects = []
45
  for r in results:
 
53
  def generate_story(detected_objects):
54
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
55
  response = client.chat.completions.create(
56
+ model="gpt-4", # Use GPT-4 or GPT-3.5-turbo
57
  messages=[{"role": "user", "content": story_prompt}],
58
  max_tokens=200
59
  )
60
  return response.choices[0].message.content.strip()
61
 
62
  # Function to summarize the story into scenes
63
+ def summarize_story(story, summarizer):
64
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
65
  scenes = sent_tokenize(summary)
66
  return scenes
67
 
68
  # Function to generate images for each scene
69
+ def generate_images(story, stable_diffusion):
70
  scenes = summarize_story(story)
71
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
72
  images = []
 
89
  image_path = f"temp_{uuid.uuid4().hex}.jpg"
90
  image.save(image_path)
91
 
92
+ # Lazy-load models
93
+ yolo_model = load_yolo_model()
94
+ stable_diffusion = load_stable_diffusion()
95
+ summarizer = load_summarizer()
96
+
97
  # Detect objects in the image
98
+ detected_objects = detect_objects(image_path, yolo_model)
99
  if not detected_objects:
100
  return "No objects detected. Please upload a different image.", "", [], None
101
 
 
105
  return "Failed to generate a story. Please try again.", "", [], None
106
 
107
  # Summarize the story into scenes
108
+ scenes = summarize_story(story, summarizer)
109
  if not scenes:
110
  return story, "No scenes extracted.", [], None
111
 
112
  # Generate images for each scene
113
+ images = generate_images(story, stable_diffusion)
114
  if not images:
115
  return story, "\n".join(scenes), [], None
116
 
 
125
  except Exception as e:
126
  return f"An error occurred: {str(e)}", "", [], None
127
 
128
+ # Gradio UI with queue for long-running tasks
129
  demo = gr.Interface(
130
  fn=full_pipeline,
131
  inputs=gr.Image(type="pil"),
 
139
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
140
  )
141
 
142
+ # Enable queue for long-running tasks
143
+ demo.queue()
144
+
145
  # Launch the app
146
  if __name__ == "__main__":
147
  demo.launch()