ali-kanbar committed · Commit 68705f6 · verified · 1 Parent(s): f450b5a

Update app.py

Files changed (1): app.py (+71, −31)
app.py CHANGED
@@ -10,9 +10,6 @@ from functools import partial
 import torch
 import imageio
 import cv2
-from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
 from PIL import Image
 import edge_tts
 from transformers import AutoTokenizer, pipeline
@@ -29,6 +26,11 @@ text_pipe = pipeline(
 # Initialize the sentiment analyzer
 sentiment_analyzer = pipeline("sentiment-analysis")
 
+# Load diffusers libraries after tokenizer to avoid GPU memory conflicts
+from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
 # Initialize video generation components
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -37,13 +39,27 @@ repo = "ByteDance/AnimateDiff-Lightning"
 ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
 base = "emilianJR/epiCRealism"
 
-# Load motion adapter
-adapter = MotionAdapter().to(device, dtype)
-adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+print(f"Using device: {device} with dtype: {dtype}")
+
+# Load motion adapter and pipeline in a function to handle errors gracefully
+def load_models():
+    try:
+        print("Loading motion adapter...")
+        adapter = MotionAdapter().to(device, dtype)
+        adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+
+        print("Loading diffusion pipeline...")
+        pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+        return adapter, pipe
+    except Exception as e:
+        print(f"Error loading models: {str(e)}")
+        traceback.print_exc()
+        return None, None
 
-# Load pipeline
-pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
-pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+# We'll load the models on first use to avoid startup errors
+adapter, pipe = None, None
 
 # Define all required functions
 def summarize(text):
@@ -141,7 +157,7 @@ def generate_story(prompt):
 
     full_output = generated[0]['generated_text']
     story = full_output.split("assistant\n")[-1].strip()
-
+
     # Process sentences and check constraints
     sentences = []
    for s in story.split('.'):
@@ -219,6 +235,15 @@ def generate_story(prompt):
     return final_story
 
 def generate_video(summary):
+    global adapter, pipe
+
+    # Load models if not already loaded
+    if adapter is None or pipe is None:
+        adapter, pipe = load_models()
+
+    if adapter is None or pipe is None:
+        raise Exception("Failed to load models. Please check the logs for errors.")
+
     def crossfade_transition(frames1, frames2, transition_length=10):
         blended_frames = []
         frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
@@ -241,6 +266,12 @@ def generate_video(summary):
     sentences = [s.strip() for s in sentences if s.strip()]
     print(f"Total scenes: {len(sentences)}")
 
+    # For development/testing purposes, limit the number of sentences
+    max_sentences = 5
+    if len(sentences) > max_sentences:
+        print(f"Limiting to first {max_sentences} sentences for faster testing")
+        sentences = sentences[:max_sentences]
+
     # Output config
     output_dir = "generated_frames"
     video_path = "generated_video.mp4"
@@ -256,23 +287,32 @@ def generate_video(summary):
         batch_prompts = sentences[i : i + batch_size]
         for idx, prompt in enumerate(batch_prompts):
             print(f"Generating animation for prompt {i+idx+1}/{len(sentences)}: {prompt}")
-            output = pipe(
-                prompt=prompt,
-                guidance_scale=1.0,
-                num_inference_steps=step,
-                width=256,
-                height=256,
-            )
-            frames = output.frames[0]
-
-            if previous_frames is not None:
-                transition = crossfade_transition(previous_frames, frames, transition_frames)
-                all_frames.extend(transition)
-
-            all_frames.extend(frames)
-            previous_frames = frames
+            try:
+                output = pipe(
+                    prompt=prompt,
+                    guidance_scale=1.0,
+                    num_inference_steps=step,
+                    width=256,
+                    height=256,
+                )
+                frames = output.frames[0]
+
+                if previous_frames is not None:
+                    transition = crossfade_transition(previous_frames, frames, transition_frames)
+                    all_frames.extend(transition)
+
+                all_frames.extend(frames)
+                previous_frames = frames
+            except Exception as e:
+                print(f"Error generating frames for prompt: {prompt}")
+                print(f"Error details: {str(e)}")
+                # Continue with next prompt if one fails
 
     # Save video
+    if not all_frames:
+        raise Exception("No frames were generated. Video creation failed.")
+
+    print(f"Saving video with {len(all_frames)} frames")
     imageio.mimsave(video_path, all_frames, fps=8)
     print(f"Video saved at {video_path}")
     return video_path
@@ -434,15 +474,10 @@ EXAMPLE_PROMPTS = [
     "A struggling local restaurant owner finds an innovative way to save their business during an economic downturn.",
     "An environmental scientist tracks mysterious wildlife behavior that reveals concerning climate changes.",
     "A community comes together to rebuild after a devastating natural disaster.",
-    "A teacher develops a unique method that transforms learning for students with special needs.",
-    "An elderly person reconnects with a childhood friend through social media after sixty years apart.",
-    "A food delivery driver forms an unexpected friendship with an isolated elderly customer during the pandemic.",
-    "A first-generation college student overcomes significant obstacles to achieve academic success.",
-    "A wildlife photographer documents the surprising recovery of an endangered species."
 ]
 
 # Create the Gradio interface
-with gr.Blocks(title="Animind AI Story Video Generator", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="AI Story Video Generator", theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🎬 AI Story Video Generator")
     gr.Markdown("Enter a one-sentence prompt to generate a complete story with video and narration.")
 
@@ -503,6 +538,11 @@ with gr.Blocks(title="Animind AI Story Video Generator", theme=gr.themes.Soft())
     - Include interesting characters, settings, or situations
     - Make your prompt realistic but with potential for development
    - Try to suggest a potential conflict or discovery
+
+    ## Note on Processing Time
+
+    For faster testing, the app currently processes only the first 5 sentences of the story.
+    In a production environment, this limit would be removed.
 
     ## Troubleshooting
 
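The core of the change is that model construction moves out of import time: `load_models()` builds the motion adapter and pipeline lazily, and `generate_video` calls it on first use, so the Space can still start (and report a readable error) when the GPU models fail to load. The same load-once-on-first-use idea can be written more compactly with a cached loader; the sketch below is an illustration of that pattern, not the code in this commit, and it reuses the module-level settings from the diff (`step` is guessed as 4 here, since its value sits outside the visible hunks).

```python
from functools import lru_cache

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
step = 4  # assumed Lightning step count; app.py defines this outside the visible hunks
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"

@lru_cache(maxsize=1)
def get_pipeline():
    # Build the adapter and pipeline exactly once, the first time a video is requested.
    adapter = MotionAdapter().to(device, dtype)
    adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
    pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
    )
    return pipe
```

The commit instead keeps `adapter, pipe = None, None` as globals with an explicit `if adapter is None or pipe is None` guard, which lets `load_models()` return `None, None` on failure so `generate_video` can raise a user-facing error.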
 
 
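Only the first lines of the `crossfade_transition` helper are visible in the hunks above. A minimal sketch of how such a helper is commonly completed — linearly blending the last frames of the previous clip into the first frames of the next, assuming the PIL frames that `pipe(...).frames[0]` returns — might look like this (illustrative, not the commit's exact implementation):

```python
import numpy as np
from PIL import Image

def crossfade_transition(frames1, frames2, transition_length=10):
    # Blend the tail of clip 1 into the head of clip 2 for a smooth scene change.
    blended_frames = []
    frames1_np = [np.array(frame) for frame in frames1[-transition_length:]]
    frames2_np = [np.array(frame) for frame in frames2[:transition_length]]
    for i, (f1, f2) in enumerate(zip(frames1_np, frames2_np)):
        alpha = (i + 1) / (transition_length + 1)  # ramps from ~0 to ~1 across the transition
        blended = ((1 - alpha) * f1 + alpha * f2).astype(np.uint8)
        blended_frames.append(Image.fromarray(blended))
    return blended_frames
```

At the fps=8 used by `imageio.mimsave`, the default `transition_length=10` corresponds to roughly 1.25 seconds of cross-fade between consecutive scenes.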