tournas committed on
Commit
483fc16
·
verified ·
1 Parent(s): 86ad5b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -80
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import uuid
3
  import gradio as gr
4
  import torch
5
  import nltk
@@ -9,125 +8,78 @@ from diffusers import StableDiffusionPipeline
9
  from ultralytics import YOLO
10
  from gtts import gTTS
11
  from PIL import Image
 
12
  from nltk.tokenize import sent_tokenize
 
13
  import spaces
14
 
15
- # Set device (use GPU if available, but don't initialize CUDA here)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
- # Load environment variables
19
  api_key = os.getenv("OPENAI_API_KEY")
20
  if not api_key:
21
- raise ValueError("⚠️ OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")
22
 
23
- # Initialize OpenAI client
24
  client = OpenAI(api_key=api_key)
25
 
26
- # Download NLTK data
 
 
27
  nltk.download("punkt")
28
 
29
- @spaces.GPU
30
- # Lazy-load models to avoid initializing CUDA in the main process
31
- def load_yolo_model():
32
- return YOLO("yolov8s.pt")
33
-
34
- def load_stable_diffusion():
35
- return StableDiffusionPipeline.from_pretrained(
36
- "runwayml/stable-diffusion-v1-5",
37
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
38
- ).to(device)
39
 
40
- def load_summarizer():
41
- return pipeline("summarization", model="facebook/bart-large-cnn", device=0 if device == "cuda" else -1)
42
-
43
- # Function to detect objects in an image
44
- def detect_objects(image_path, yolo_model):
45
  results = yolo_model(image_path)
46
  detected_objects = []
47
  for r in results:
48
  for box in r.boxes:
49
- class_id = int(box.cls.item())
50
  label = yolo_model.names[class_id]
51
  detected_objects.append(label)
52
  return detected_objects
53
 
54
- # Function to generate a story based on detected objects
55
  def generate_story(detected_objects):
56
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
57
- response = client.chat.completions.create(
58
- model="gpt-4", # Use GPT-4 or GPT-3.5-turbo
59
  messages=[{"role": "user", "content": story_prompt}],
60
  max_tokens=200
61
  )
62
- return response.choices[0].message.content.strip()
63
 
64
- # Function to summarize the story into scenes
65
- def summarize_story(story, summarizer):
66
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
67
  scenes = sent_tokenize(summary)
68
  return scenes
69
 
70
- # Function to generate images for each scene
71
- def generate_images(story, stable_diffusion):
72
  scenes = summarize_story(story)
73
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
74
  images = []
75
  for prompt in prompts:
76
- image = stable_diffusion(prompt=prompt).images[0]
77
  images.append(image)
78
  return images
79
 
80
- # Function to convert text to speech
81
  def text_to_speech(story):
82
  tts = gTTS(text=story, lang="en", slow=False)
83
- audio_file_path = f"story_audio_{uuid.uuid4().hex}.mp3" # Unique filename
84
  tts.save(audio_file_path)
85
  return audio_file_path
86
 
87
- # Main pipeline function
88
  def full_pipeline(image):
89
- try:
90
- # Save the image with a unique filename
91
- image_path = f"temp_{uuid.uuid4().hex}.jpg"
92
- image.save(image_path)
93
-
94
- # Lazy-load models
95
- yolo_model = load_yolo_model()
96
- stable_diffusion = load_stable_diffusion()
97
- summarizer = load_summarizer()
98
-
99
- # Detect objects in the image
100
- detected_objects = detect_objects(image_path, yolo_model)
101
- if not detected_objects:
102
- return "No objects detected. Please upload a different image.", "", [], None
103
-
104
- # Generate a story based on detected objects
105
- story = generate_story(detected_objects)
106
- if not story:
107
- return "Failed to generate a story. Please try again.", "", [], None
108
-
109
- # Summarize the story into scenes
110
- scenes = summarize_story(story, summarizer)
111
- if not scenes:
112
- return story, "No scenes extracted.", [], None
113
-
114
- # Generate images for each scene
115
- images = generate_images(story, stable_diffusion)
116
- if not images:
117
- return story, "\n".join(scenes), [], None
118
-
119
- # Convert the story to audio
120
- audio = text_to_speech(story)
121
- if not audio:
122
- return story, "\n".join(scenes), images, None
123
-
124
- # Return all outputs
125
- return story, "\n".join(scenes), images, audio
126
-
127
- except Exception as e:
128
- return f"An error occurred: {str(e)}", "", [], None
129
 
130
- # Gradio UI with queue for long-running tasks
131
  demo = gr.Interface(
132
  fn=full_pipeline,
133
  inputs=gr.Image(type="pil"),
@@ -141,9 +93,5 @@ demo = gr.Interface(
141
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
142
  )
143
 
144
- # Enable queue for long-running tasks
145
- demo.queue()
146
-
147
- # Launch the app
148
  if __name__ == "__main__":
149
- demo.launch()
 
1
  import os
 
2
  import gradio as gr
3
  import torch
4
  import nltk
 
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  from PIL import Image
11
+ import numpy as np
12
  from nltk.tokenize import sent_tokenize
13
+ from IPython.display import Audio
14
  import spaces
15
 
 
# Pick the GPU when torch reports one; otherwise everything runs on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The OpenAI key comes from the environment; startup aborts when it is absent.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("\u26a0\ufe0f OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")

client = OpenAI(api_key=api_key)

# Models are created once at import time and shared by the functions below.
yolo_model = YOLO("yolov8s.pt")
stable_diffusion = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
)
stable_diffusion.to(device)

# Tokenizer data; presumably what sent_tokenize relies on.
nltk.download("punkt")

# Hugging Face pipelines take a device index: 0 for the first GPU, -1 for CPU.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=0 if torch.cuda.is_available() else -1,
)
 
 
 
 
 
 
 
 
 
30
 
@spaces.GPU
def detect_objects(image_path):
    """Return the class label of every box YOLO detects in the image.

    Runs the module-level ``yolo_model`` on ``image_path`` and maps each
    detected box's class id to its name through ``yolo_model.names``.
    One label is emitted per box, so the same label can appear repeatedly.
    """
    return [
        # box.cls is a one-element tensor; .item() extracts the class id.
        yolo_model.names[int(box.cls.item())]
        for result in yolo_model(image_path)
        for box in result.boxes
    ]
41
 
 
def generate_story(detected_objects):
    """Ask the OpenAI chat model for a short story about the given objects.

    Args:
        detected_objects: Iterable of label strings; they are joined with
            ", " and embedded in the prompt.

    Returns:
        The generated story with surrounding whitespace stripped, or an
        empty string if the API response carries no text content.
    """
    story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
    # `messages=` is a Chat Completions argument and gpt-4o-mini is a chat
    # model, so the request must go through client.chat.completions; the
    # legacy client.completions endpoint expects `prompt=` instead.
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": story_prompt}],
        max_tokens=200,
    )
    # Chat responses expose the text as message.content (it may be None).
    content = response.choices[0].message.content
    return (content or "").strip()
50
 
def summarize_story(story):
    """Summarise ``story`` and return the summary as a list of sentences.

    The module-level ``summarizer`` is called deterministically
    (``do_sample=False``) with a summary cap of 100 tokens; the resulting
    text is split into sentences with ``sent_tokenize``.
    """
    outputs = summarizer(story, max_length=100, do_sample=False)
    return sent_tokenize(outputs[0]['summary_text'])
55
 
def generate_images(story):
    """Render one Stable Diffusion image per summary sentence of ``story``.

    The story is summarised into scenes via ``summarize_story``; each scene
    is wrapped in a fixed styling prompt and passed to the module-level
    ``stable_diffusion`` pipeline. Returns the first image of each call,
    in scene order.
    """
    # NOTE(review): unlike detect_objects, this function is not decorated
    # with @spaces.GPU -- confirm whether the Space runs on ZeroGPU hardware.
    return [
        stable_diffusion(
            prompt=f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting"
        ).images[0]
        for scene in summarize_story(story)
    ]
64
 
 
def text_to_speech(story):
    """Narrate ``story`` with gTTS and return the path of the saved MP3.

    Args:
        story: English text to be spoken.

    Returns:
        Path (relative to the working directory) of the generated MP3 file.
    """
    # uuid is not imported at module level, so bring it in locally.
    import uuid

    tts = gTTS(text=story, lang="en", slow=False)
    # A fixed filename would be shared by every request, letting concurrent
    # users overwrite each other's narration; use a unique name per call.
    audio_file_path = f"story_audio_{uuid.uuid4().hex}.mp3"
    tts.save(audio_file_path)
    return audio_file_path
70
 
 
def full_pipeline(image):
    """Run detection, story, scenes, images and narration for one upload.

    Args:
        image: PIL image supplied by the Gradio input component
            (``None`` when the user submits without uploading anything).

    Returns:
        Tuple ``(story, scenes, images, audio)``: the story text, the list
        of summary sentences, the list of generated images, and the path of
        the narration MP3.

    Raises:
        gr.Error: If no image was provided.
    """
    import tempfile

    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    # A fixed "input.jpg" would be shared by all requests; reserve a unique
    # temporary path so concurrent uploads cannot overwrite each other.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        image_path = tmp.name
    try:
        # JPEG cannot store alpha or palette modes, so normalise to RGB.
        image.convert("RGB").save(image_path)
        detected_objects = detect_objects(image_path)
    finally:
        # The file is only needed for detection; do not leave it behind.
        os.remove(image_path)

    story = generate_story(detected_objects)
    # NOTE(review): generate_images(story) summarises the story again
    # internally, so the summarizer runs twice per request.
    scenes = summarize_story(story)
    images = generate_images(story)
    audio = text_to_speech(story)

    # NOTE(review): scenes is returned as a list; the output components are
    # not visible here -- confirm the second output accepts a list.
    return story, scenes, images, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # **Gradio UI**
83
  demo = gr.Interface(
84
  fn=full_pipeline,
85
  inputs=gr.Image(type="pil"),
 
93
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
94
  )
95
 
 
 
 
 
96
  if __name__ == "__main__":
97
+ demo.launch()