tournas commited on
Commit
80b7403
·
verified ·
1 Parent(s): 17a4149

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -22
app.py CHANGED
@@ -1,59 +1,70 @@
1
  import os
 
2
  import gradio as gr
3
  import torch
4
  import nltk
5
  from openai import OpenAI
6
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
7
  from diffusers import StableDiffusionPipeline
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  from PIL import Image
11
- import numpy as np
12
  from nltk.tokenize import sent_tokenize
13
- from IPython.display import Audio
14
  import spaces
15
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
 
18
  api_key = os.getenv("OPENAI_API_KEY")
19
  if not api_key:
20
  raise ValueError("⚠️ OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")
21
 
22
-
23
  client = OpenAI(api_key=api_key)
24
 
25
-
26
  yolo_model = YOLO("yolov8s.pt")
27
- stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 
 
28
  stable_diffusion.to(device)
 
 
29
  nltk.download("punkt")
30
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device= 'cuda' if torch.cuda.is_available() else 'cpu')
 
 
31
 
32
  @spaces.GPU
 
33
  def detect_objects(image_path):
34
  results = yolo_model(image_path)
35
  detected_objects = []
36
  for r in results:
37
  for box in r.boxes:
38
- class_id = int(box.cls[0])
39
  label = yolo_model.names[class_id]
40
  detected_objects.append(label)
41
  return detected_objects
42
 
 
43
  def generate_story(detected_objects):
44
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
45
  response = client.chat.completions.create(
46
- model="gpt-4o-mini",
47
  messages=[{"role": "user", "content": story_prompt}],
48
  max_tokens=200
49
  )
50
  return response.choices[0].message.content.strip()
51
 
 
52
  def summarize_story(story):
53
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
54
  scenes = sent_tokenize(summary)
55
  return scenes
56
 
 
57
  def generate_images(story):
58
  scenes = summarize_story(story)
59
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
@@ -63,24 +74,52 @@ def generate_images(story):
63
  images.append(image)
64
  return images
65
 
 
66
  def text_to_speech(story):
67
  tts = gTTS(text=story, lang="en", slow=False)
68
- audio_file_path = "story_audio.mp3"
69
  tts.save(audio_file_path)
70
  return audio_file_path
71
 
 
72
  def full_pipeline(image):
73
- image_path = "input.jpg"
74
- image.save(image_path)
75
- detected_objects = detect_objects(image_path)
76
- story = generate_story(detected_objects)
77
- scenes = summarize_story(story)
78
- images = generate_images(story)
79
- audio = text_to_speech(story)
80
-
81
- return story, scenes, images, audio
82
-
83
- # **Gradio UI**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  demo = gr.Interface(
85
  fn=full_pipeline,
86
  inputs=gr.Image(type="pil"),
@@ -94,6 +133,6 @@ demo = gr.Interface(
94
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
95
  )
96
 
97
-
98
  if __name__ == "__main__":
99
  demo.launch()
 
1
  import os
2
+ import uuid
3
  import gradio as gr
4
  import torch
5
  import nltk
6
  from openai import OpenAI
7
+ from transformers import pipeline
8
  from diffusers import StableDiffusionPipeline
9
  from ultralytics import YOLO
10
  from gtts import gTTS
11
  from PIL import Image
 
12
  from nltk.tokenize import sent_tokenize
 
13
  import spaces
14
 
15
+ # Set device (use GPU if available)
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
+ # Load environment variables
19
  api_key = os.getenv("OPENAI_API_KEY")
20
  if not api_key:
21
  raise ValueError("⚠️ OpenAI API Key is missing! Add it as a Secret in Hugging Face Spaces.")
22
 
23
+ # Initialize OpenAI client
24
  client = OpenAI(api_key=api_key)
25
 
26
+ # Load YOLO model
27
  yolo_model = YOLO("yolov8s.pt")
28
+
29
+ # Load Stable Diffusion pipeline
30
+ stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32)
31
  stable_diffusion.to(device)
32
+
33
+ # Download NLTK data
34
  nltk.download("punkt")
35
+
36
+ # Load summarization pipeline
37
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if device == "cuda" else -1)
38
 
39
  @spaces.GPU
40
+ # Function to detect objects in an image
41
  def detect_objects(image_path):
42
  results = yolo_model(image_path)
43
  detected_objects = []
44
  for r in results:
45
  for box in r.boxes:
46
+ class_id = int(box.cls.item())
47
  label = yolo_model.names[class_id]
48
  detected_objects.append(label)
49
  return detected_objects
50
 
51
+ # Function to generate a story based on detected objects
52
  def generate_story(detected_objects):
53
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
54
  response = client.chat.completions.create(
55
+ model="gpt-4", # Use GPT-4 or GPT-3.5-turbo
56
  messages=[{"role": "user", "content": story_prompt}],
57
  max_tokens=200
58
  )
59
  return response.choices[0].message.content.strip()
60
 
61
+ # Function to summarize the story into scenes
62
  def summarize_story(story):
63
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
64
  scenes = sent_tokenize(summary)
65
  return scenes
66
 
67
+ # Function to generate images for each scene
68
  def generate_images(story):
69
  scenes = summarize_story(story)
70
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
 
74
  images.append(image)
75
  return images
76
 
77
+ # Function to convert text to speech
78
  def text_to_speech(story):
79
  tts = gTTS(text=story, lang="en", slow=False)
80
+ audio_file_path = f"story_audio_{uuid.uuid4().hex}.mp3" # Unique filename
81
  tts.save(audio_file_path)
82
  return audio_file_path
83
 
84
+ # Main pipeline function
85
  def full_pipeline(image):
86
+ try:
87
+ # Save the image with a unique filename
88
+ image_path = f"temp_{uuid.uuid4().hex}.jpg"
89
+ image.save(image_path)
90
+
91
+ # Detect objects in the image
92
+ detected_objects = detect_objects(image_path)
93
+ if not detected_objects:
94
+ return "No objects detected. Please upload a different image.", "", [], None
95
+
96
+ # Generate a story based on detected objects
97
+ story = generate_story(detected_objects)
98
+ if not story:
99
+ return "Failed to generate a story. Please try again.", "", [], None
100
+
101
+ # Summarize the story into scenes
102
+ scenes = summarize_story(story)
103
+ if not scenes:
104
+ return story, "No scenes extracted.", [], None
105
+
106
+ # Generate images for each scene
107
+ images = generate_images(story)
108
+ if not images:
109
+ return story, "\n".join(scenes), [], None
110
+
111
+ # Convert the story to audio
112
+ audio = text_to_speech(story)
113
+ if not audio:
114
+ return story, "\n".join(scenes), images, None
115
+
116
+ # Return all outputs
117
+ return story, "\n".join(scenes), images, audio
118
+
119
+ except Exception as e:
120
+ return f"An error occurred: {str(e)}", "", [], None
121
+
122
+ # Gradio UI
123
  demo = gr.Interface(
124
  fn=full_pipeline,
125
  inputs=gr.Image(type="pil"),
 
133
  description="Upload an image, and the AI will detect objects, generate a story, create images, and narrate the story."
134
  )
135
 
136
+ # Launch the app
137
  if __name__ == "__main__":
138
  demo.launch()