tournas commited on
Commit
1aaa563
·
verified ·
1 Parent(s): 16a5d71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -49
app.py CHANGED
@@ -12,6 +12,7 @@ import numpy as np
12
  from nltk.tokenize import sent_tokenize
13
  import spaces
14
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
  api_key = os.getenv("OPENAI_API_KEY")
@@ -20,17 +21,18 @@ if not api_key:
20
 
21
  client = OpenAI(api_key=api_key)
22
 
23
- yolo_model = YOLO("yolov8s.pt")
 
24
 
 
25
  stable_diffusion = StableDiffusionPipeline.from_pretrained(
26
  "runwayml/stable-diffusion-v1-5",
27
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
28
  ).to(device)
29
 
 
30
 
31
- stable_diffusion.vae.enable_tiling = False
32
-
33
- nltk.download("punkt")
34
 
35
  summarizer = pipeline(
36
  "summarization",
@@ -38,74 +40,105 @@ summarizer = pipeline(
38
  )
39
 
40
  def detect_objects(image):
41
- yolo_model.to('cuda')
42
- image_array = np.array(image) # Μετατροπή PIL → NumPy
43
- results = yolo_model(image_array)
44
- detected_objects = []
45
- for r in results:
46
- for box in r.boxes:
47
- class_id = int(box.cls.item())
48
- label = yolo_model.names[class_id]
49
- detected_objects.append(label)
50
- return detected_objects
 
 
 
 
 
51
 
52
  def generate_story(detected_objects):
53
- story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
54
- response = client.chat.completions.create(
55
- model="gpt-4o-mini",
56
- messages=[{"role": "user", "content": story_prompt}],
57
- max_tokens=200
58
- )
59
- return response.choices[0].message.content.strip()
 
 
 
 
60
 
61
  def summarize_story(story):
62
- summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
63
- scenes = sent_tokenize(summary)
64
- return scenes
 
 
 
 
65
 
66
  def generate_images(story):
67
  scenes = summarize_story(story)
68
  images = []
69
 
70
- # Περιορισμός σε μέγιστο 3 σκηνές για αποφυγή υπερφόρτωσης
71
- scenes = scenes[:min(len(scenes), 3)]
72
-
73
  for prompt in scenes:
74
  try:
75
- with torch.no_grad(): # Μειώνει τη χρήση μνήμης
76
- prompt_text = f"Highly detailed, cinematic scene: {prompt}, digital art, 4K, realistic lighting"
77
- # Προσθέτω παραμέτρους για καλύτερη διαχείριση μνήμης
78
  image = stable_diffusion(
79
  prompt_text,
80
- num_inference_steps=30, # Μείωση από το προεπιλεγμένο 50
81
- guidance_scale=7.5
 
 
82
  ).images[0]
83
  images.append(image)
84
- # Καθαρισμός μνήμης μετά από κάθε δημιουργία
 
85
  if torch.cuda.is_available():
86
  torch.cuda.empty_cache()
87
  except Exception as e:
88
- print(f"Error generating image for scene: {e}")
89
- # Συνέχισε με την επόμενη σκηνή σε περίπτ��ση σφάλματος
90
- continue
91
-
92
- return images
93
 
94
  def text_to_speech(story):
95
- tts = gTTS(text=story, lang="en", slow=False)
96
- audio_file_path = "story_audio.mp3"
97
- tts.save(audio_file_path)
98
- return audio_file_path
 
 
 
 
99
 
100
  @spaces.GPU
101
  def full_pipeline(image):
102
- detected_objects = detect_objects(image)
103
- story = generate_story(detected_objects)
104
- scenes = summarize_story(story)
105
- images = generate_images(story)
106
- audio = text_to_speech(story)
107
-
108
- return story, scenes, images, audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  # **Gradio UI**
111
  demo = gr.Interface(
 
12
  from nltk.tokenize import sent_tokenize
13
  import spaces
14
 
15
+ # Ensure minimal GPU usage
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  api_key = os.getenv("OPENAI_API_KEY")
 
21
 
22
  client = OpenAI(api_key=api_key)
23
 
24
+ # Use smallest YOLO model
25
+ yolo_model = YOLO("yolov8n.pt")
26
 
27
+ # Lightweight Stable Diffusion configuration
28
  stable_diffusion = StableDiffusionPipeline.from_pretrained(
29
  "runwayml/stable-diffusion-v1-5",
30
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
31
  ).to(device)
32
 
33
+ stable_diffusion.vae.enable_tiling = True # Enable tiling to reduce memory usage
34
 
35
+ nltk.download("punkt", quiet=True)
 
 
36
 
37
  summarizer = pipeline(
38
  "summarization",
 
40
  )
41
 
42
  def detect_objects(image):
43
+ try:
44
+ # Move model to appropriate device
45
+ yolo_model.to(device)
46
+ image_array = np.array(image)
47
+ results = yolo_model(image_array)
48
+ detected_objects = []
49
+ for r in results:
50
+ for box in r.boxes:
51
+ class_id = int(box.cls.item())
52
+ label = yolo_model.names[class_id]
53
+ detected_objects.append(label)
54
+ return list(set(detected_objects)) # Remove duplicates
55
+ except Exception as e:
56
+ print(f"Object detection error: {e}")
57
+ return ["generic", "objects"]
58
 
59
  def generate_story(detected_objects):
60
+ try:
61
+ story_prompt = f"Write a concise, creative short story using these objects: {', '.join(detected_objects)}"
62
+ response = client.chat.completions.create(
63
+ model="gpt-3.5-turbo", # More lightweight model
64
+ messages=[{"role": "user", "content": story_prompt}],
65
+ max_tokens=150 # Reduced token count
66
+ )
67
+ return response.choices[0].message.content.strip()
68
+ except Exception as e:
69
+ print(f"Story generation error: {e}")
70
+ return "A mysterious tale of adventure and discovery."
71
 
72
  def summarize_story(story):
73
+ try:
74
+ summary = summarizer(story, max_length=50, do_sample=False)[0]['summary_text']
75
+ scenes = sent_tokenize(summary)
76
+ return scenes[:2] # Limit to 2 scenes to reduce computational load
77
+ except Exception as e:
78
+ print(f"Story summarization error: {e}")
79
+ return ["A peaceful scene", "An exciting moment"]
80
 
81
  def generate_images(story):
82
  scenes = summarize_story(story)
83
  images = []
84
 
 
 
 
85
  for prompt in scenes:
86
  try:
87
+ with torch.no_grad():
88
+ # Simplified, less computationally intensive prompt
89
+ prompt_text = f"Simple illustration: {prompt}, soft colors"
90
  image = stable_diffusion(
91
  prompt_text,
92
+ num_inference_steps=20, # Reduced steps
93
+ guidance_scale=6.0, # Slightly lower guidance
94
+ height=256, # Smaller image
95
+ width=256
96
  ).images[0]
97
  images.append(image)
98
+
99
+ # Aggressive memory clearing
100
  if torch.cuda.is_available():
101
  torch.cuda.empty_cache()
102
  except Exception as e:
103
+ print(f"Image generation error: {e}")
104
+
105
+ # Fallback if no images generated
106
+ return images if images else [Image.new('RGB', (256, 256), color='lightgray')]
 
107
 
108
  def text_to_speech(story):
109
+ try:
110
+ tts = gTTS(text=story[:500], lang="en", slow=False) # Limit to first 500 chars
111
+ audio_file_path = "story_audio.mp3"
112
+ tts.save(audio_file_path)
113
+ return audio_file_path
114
+ except Exception as e:
115
+ print(f"Text-to-speech error: {e}")
116
+ return None
117
 
118
  @spaces.GPU
119
  def full_pipeline(image):
120
+ # Wrap entire process with error handling
121
+ try:
122
+ detected_objects = detect_objects(image)
123
+ story = generate_story(detected_objects)
124
+ scenes = summarize_story(story)
125
+ images = generate_images(story)
126
+ audio = text_to_speech(story)
127
+
128
+ return (
129
+ story or "A story could not be generated.",
130
+ scenes or ["Scene 1", "Scene 2"],
131
+ images,
132
+ audio
133
+ )
134
+ except Exception as e:
135
+ print(f"Full pipeline error: {e}")
136
+ return (
137
+ "An unexpected error occurred.",
138
+ ["Something went wrong"],
139
+ [Image.new('RGB', (256, 256), color='lightgray')],
140
+ None
141
+ )
142
 
143
  # **Gradio UI**
144
  demo = gr.Interface(