awacke1 committed
Commit 55c19e3 · verified · 1 Parent(s): 6a858a9

Update app.py

Files changed (1): app.py (+77 -56)
app.py CHANGED
@@ -1,6 +1,7 @@
+# app.py
 import gradio as gr
-from transformers import pipeline
-from diffusers import StableDiffusionPipeline
+from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
+from diffusers import StableDiffusionPipeline, DiffusionPipeline
 import torch
 from PIL import Image
 import numpy as np
@@ -12,12 +13,69 @@ from pydub import AudioSegment
 import warnings
 import asyncio
 import edge_tts
+import random
+from openai import OpenAI
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
 # Ensure NLTK data is downloaded
 nltk.download('punkt')
 
+# LLM Inference Class
+class LLMInferenceNode:
+    def __init__(self):
+        self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+        self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
+
+        self.huggingface_client = OpenAI(
+            base_url="https://api-inference.huggingface.co/v1/",
+            api_key=self.huggingface_token,
+        )
+        self.sambanova_client = OpenAI(
+            api_key=self.sambanova_api_key,
+            base_url="https://api.sambanova.ai/v1",
+        )
+
+    def generate(self, input_text, long_talk=True, compress=False,
+                 compression_level="medium", poster=False, prompt_type="Short",
+                 provider="Hugging Face", model=None):
+        try:
+            # Define system message
+            system_message = "You are a helpful assistant. Try your best to give the best response possible to the user."
+
+            # Define base prompts based on type
+            prompts = {
+                "Short": """Create a brief, straightforward caption for this description, suitable for a text-to-image AI system.
+Focus on the main elements, key characters, and overall scene without elaborate details.""",
+                "Long": """Create a detailed visually descriptive caption of this description for a text-to-image AI system.
+Include detailed visual descriptions, cinematography, and lighting setup."""
+            }
+
+            base_prompt = prompts.get(prompt_type, prompts["Short"])
+            user_message = f"{base_prompt}\nDescription: {input_text}"
+
+            # Generate with selected provider
+            if provider == "Hugging Face":
+                client = self.huggingface_client
+            else:
+                client = self.sambanova_client
+
+            response = client.chat.completions.create(
+                model=model or "meta-llama/Meta-Llama-3.1-70B-Instruct",
+                max_tokens=1024,
+                temperature=1.0,
+                messages=[
+                    {"role": "system", "content": system_message},
+                    {"role": "user", "content": user_message},
+                ]
+            )
+
+            return response.choices[0].message.content.strip()
+
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            return f"Error occurred while processing the request: {str(e)}"
+
 # Initialize models
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if device == "cuda" else torch.float32
@@ -30,14 +88,19 @@ story_generator = pipeline(
 )
 
 # Stable Diffusion model
-sd_model_id = "runwayml/stable-diffusion-v1-5"
 sd_pipe = StableDiffusionPipeline.from_pretrained(
-    sd_model_id,
+    "runwayml/stable-diffusion-v1-5",
     torch_dtype=torch_dtype
-)
-sd_pipe = sd_pipe.to(device)
+).to(device)
 
 # Text-to-Speech function using edge_tts
+async def _text2speech_async(text):
+    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
+        tmp_path = tmp_file.name
+    await communicate.save(tmp_path)
+    return tmp_path
+
 def text2speech(text):
     try:
         output_path = asyncio.run(_text2speech_async(text))
@@ -46,13 +109,6 @@ def text2speech(text):
         print(f"Error in text2speech: {str(e)}")
         raise
 
-async def _text2speech_async(text):
-    communicate = edge_tts.Communicate(text, voice="en-US-AriaNeural")
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
-        tmp_path = tmp_file.name
-    await communicate.save(tmp_path)
-    return tmp_path
-
 def generate_story(prompt):
     generated = story_generator(prompt, max_length=500, num_return_sequences=1)
     story = generated[0]['generated_text']
@@ -66,7 +122,6 @@ def generate_images(sentences):
     images = []
     for idx, sentence in enumerate(sentences):
         image = sd_pipe(sentence).images[0]
-        # Save image to temporary file
         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{idx}.png")
         image.save(temp_file.name)
         images.append(temp_file.name)
@@ -75,72 +130,38 @@ def generate_audio(story_text):
 def generate_audio(story_text):
     audio_path = text2speech(story_text)
     audio = AudioSegment.from_file(audio_path)
-    total_duration = len(audio) / 1000  # duration in seconds
+    total_duration = len(audio) / 1000
     return audio_path, total_duration
 
 def compute_sentence_durations(sentences, total_duration):
     total_words = sum(len(sentence.split()) for sentence in sentences)
-    sentence_durations = []
-    for sentence in sentences:
-        num_words = len(sentence.split())
-        duration = total_duration * (num_words / total_words)
-        sentence_durations.append(duration)
-    return sentence_durations
+    return [total_duration * (len(sentence.split()) / total_words) for sentence in sentences]
 
 def create_video(images, durations, audio_path):
-    clips = []
-    for image_path, duration in zip(images, durations):
-        clip = mpe.ImageClip(image_path).set_duration(duration)
-        clips.append(clip)
+    clips = [mpe.ImageClip(img).set_duration(dur) for img, dur in zip(images, durations)]
     video = mpe.concatenate_videoclips(clips, method='compose')
     audio = mpe.AudioFileClip(audio_path)
     video = video.set_audio(audio)
-    # Save video
     output_path = os.path.join(tempfile.gettempdir(), "final_video.mp4")
     video.write_videofile(output_path, fps=1, codec='libx264')
     return output_path
-
+
 def process_pipeline(prompt, progress=gr.Progress()):
     try:
-        total_steps = 6
-        step = 0
-
-        progress(step / total_steps, desc="Generating Story")
         story = generate_story(prompt)
-        step += 1
-
-        progress(step / total_steps, desc="Splitting Story into Sentences")
         sentences = split_story_into_sentences(story)
-        step += 1
-
-        progress(step / total_steps, desc="Generating Images for Sentences")
         images = generate_images(sentences)
-        step += 1
-
-        progress(step / total_steps, desc="Generating Audio")
        audio_path, total_duration = generate_audio(story)
-        step += 1
-
-        progress(step / total_steps, desc="Computing Durations")
         durations = compute_sentence_durations(sentences, total_duration)
-        step += 1
-
-        progress(step / total_steps, desc="Creating Video")
         video_path = create_video(images, durations, audio_path)
-        step += 1
-
-        progress(1.0, desc="Completed")
-
         return video_path
     except Exception as e:
         print(f"Error in process_pipeline: {str(e)}")
         raise gr.Error(f"An error occurred: {str(e)}")
-
+
+# Gradio Interface
 title = """<h1 align="center">AI Story Video Generator 🎥</h1>
-    <p align="center">
-    Generate a story from a prompt, create images for each sentence, and produce a video with narration!
-    </p>
-    """
+    <p align="center">Generate a story from a prompt, create images for each sentence, and produce a video with narration!</p>"""
 
 with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
     gr.HTML(title)
@@ -154,4 +175,4 @@ with gr.Blocks(css=".container { max-width: 800px; margin: auto; }") as demo:
 
     generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)
 
-demo.launch(debug=True)
+demo.launch(debug=True)
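
The new LLMInferenceNode class is added above but is not yet called anywhere else in app.py. A minimal standalone sketch of how it could be exercised, assuming HUGGINGFACE_TOKEN (or SAMBANOVA_API_KEY) is set in the environment and that the snippet runs in the same module as the class definition; the prompt text is illustrative only:

# Hypothetical usage of LLMInferenceNode (not part of this commit).
# Assumes HUGGINGFACE_TOKEN is exported; pass provider="SambaNova" to use the SambaNova client instead.
node = LLMInferenceNode()
caption = node.generate(
    input_text="A fox discovers a glowing lantern in a snowy forest.",
    prompt_type="Long",        # "Short" or "Long" selects the base prompt template
    provider="Hugging Face",   # route the request through the Hugging Face OpenAI-compatible endpoint
)
print(caption)                 # a visual caption that could feed the Stable Diffusion step

The rewritten compute_sentence_durations allocates narration time to each sentence in proportion to its word count (duration_i = total_duration * words_i / total_words). A quick check of that weighting, using made-up numbers rather than anything from the app:

# Illustrative check of the word-count weighting used by compute_sentence_durations.
sentences = ["The fox ran.", "It found a glowing lantern deep in the snowy woods."]
total_duration = 12.0                                    # seconds of narration (invented)
total_words = sum(len(s.split()) for s in sentences)     # 3 + 10 = 13 words
durations = [total_duration * len(s.split()) / total_words for s in sentences]
# durations ≈ [2.77, 9.23] seconds, and they sum back to total_duration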