awacke1 commited on
Commit
96d7e87
·
verified ·
1 Parent(s): e2a70ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -166
app.py CHANGED
@@ -1,84 +1,31 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
3
- from diffusers import StableDiffusionPipeline, DiffusionPipeline
4
- import torch
5
- from PIL import Image
6
- import numpy as np
7
- import os
8
- import tempfile
9
- import moviepy.editor as mpe
10
- import nltk
11
- from pydub import AudioSegment
12
- import warnings
13
- import asyncio
14
- import edge_tts
15
  import random
16
  from datetime import datetime
 
 
 
 
 
 
17
  import pytz
18
  import re
19
  import json
20
- from gradio_client import Client
21
 
22
- warnings.filterwarnings("ignore", category=UserWarning)
23
 
24
- # Ensure NLTK data is downloaded
25
- nltk.download('punkt')
26
-
27
- # Initialize clients
28
  arxiv_client = None
29
- def init_arxiv_client():
 
30
  global arxiv_client
31
  if arxiv_client is None:
32
  arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
33
  return arxiv_client
34
 
35
- # File I/O Functions
36
- def generate_filename(prompt, timestamp=None):
37
- """Generate a safe filename from prompt and timestamp"""
38
- if timestamp is None:
39
- timestamp = datetime.now(pytz.UTC).strftime("%Y%m%d_%H%M%S")
40
- # Clean the prompt to create a safe filename
41
- safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip()
42
- return f"story_{timestamp}_{safe_prompt}.txt"
43
-
44
- def save_story(story, prompt, filename=None):
45
- """Save story to file with metadata"""
46
- if filename is None:
47
- filename = generate_filename(prompt)
48
-
49
- try:
50
- with open(filename, 'w', encoding='utf-8') as f:
51
- metadata = {
52
- 'timestamp': datetime.now().isoformat(),
53
- 'prompt': prompt,
54
- 'type': 'story'
55
- }
56
- f.write(json.dumps(metadata) + '\n---\n' + story)
57
- return filename
58
- except Exception as e:
59
- print(f"Error saving story: {e}")
60
- return None
61
-
62
- def load_story(filename):
63
- """Load story and metadata from file"""
64
- try:
65
- with open(filename, 'r', encoding='utf-8') as f:
66
- content = f.read()
67
- parts = content.split('\n---\n')
68
- if len(parts) == 2:
69
- metadata = json.loads(parts[0])
70
- story = parts[1]
71
- return metadata, story
72
- return None, content
73
- except Exception as e:
74
- print(f"Error loading story: {e}")
75
- return None, None
76
-
77
- # Story Generation Functions
78
  def generate_story(prompt, model_choice):
79
  """Generate story using specified model"""
80
  try:
81
- client = init_arxiv_client()
82
  if client is None:
83
  return "Error: Story generation service is not available."
84
 
@@ -110,115 +57,58 @@ def process_story_and_audio(prompt, model_choice):
110
  # Generate story
111
  story = generate_story(prompt, model_choice)
112
  if isinstance(story, str) and story.startswith("Error"):
113
- return story, None, None
114
-
115
- # Save story
116
- filename = save_story(story, prompt)
117
-
118
  # Generate audio
119
  audio_path = asyncio.run(generate_speech(story))
120
 
121
- return story, audio_path, filename
122
  except Exception as e:
123
- return f"Error: {str(e)}", None, None
124
-
125
- # Main App Code (your existing code remains here)
126
- # LLM Inference Class and other existing classes remain unchanged
127
- class LLMInferenceNode:
128
- # Your existing LLMInferenceNode implementation
129
- pass
130
-
131
- # Initialize models (your existing initialization code remains here)
132
- device = "cuda" if torch.cuda.is_available() else "cpu"
133
- torch_dtype = torch.float16 if device == "cuda" else torch.float32
134
 
135
- # Story generator
136
- story_generator = pipeline(
137
- 'text-generation',
138
- model='gpt2-large',
139
- device=0 if device == 'cuda' else -1
140
- )
141
-
142
- # Stable Diffusion model
143
- sd_pipe = StableDiffusionPipeline.from_pretrained(
144
- "runwayml/stable-diffusion-v1-5",
145
- torch_dtype=torch_dtype
146
- ).to(device)
147
-
148
- # Create the enhanced Gradio interface
149
- with gr.Blocks() as demo:
150
- gr.Markdown("""# 🎨 AI Creative Suite
151
- Generate videos, stories, and more with AI!
152
  """)
153
 
154
- with gr.Tabs():
155
- # Your existing video generation tab
156
- with gr.Tab("Video Generation"):
157
- with gr.Row():
158
- with gr.Column():
159
- prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
160
- generate_button = gr.Button("Generate Video")
161
- with gr.Column():
162
- video_output = gr.Video(label="Generated Video")
163
-
164
- generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)
165
-
166
- # New story generation tab
167
- with gr.Tab("Story Generation"):
168
- with gr.Row():
169
- with gr.Column():
170
- story_prompt = gr.Textbox(
171
- label="Story Concept",
172
- placeholder="Enter your story idea...",
173
- lines=3
174
- )
175
- model_choice = gr.Dropdown(
176
- label="Model",
177
- choices=[
178
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
179
- "mistralai/Mistral-7B-Instruct-v0.2"
180
- ],
181
- value="mistralai/Mixtral-8x7B-Instruct-v0.1"
182
- )
183
- generate_story_btn = gr.Button("Generate Story")
184
-
185
- with gr.Row():
186
- story_output = gr.Textbox(
187
- label="Generated Story",
188
- lines=10,
189
- interactive=False
190
- )
191
-
192
- with gr.Row():
193
- audio_output = gr.Audio(
194
- label="Story Narration",
195
- type="filepath"
196
- )
197
- filename_output = gr.Textbox(
198
- label="Saved Filename",
199
- interactive=False
200
- )
201
-
202
- generate_story_btn.click(
203
- fn=process_story_and_audio,
204
- inputs=[story_prompt, model_choice],
205
- outputs=[story_output, audio_output, filename_output]
206
  )
207
-
208
- # File management section
209
- with gr.Row():
210
- file_list = gr.Dropdown(
211
- label="Saved Stories",
212
- choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")],
213
- interactive=True
214
- )
215
- refresh_btn = gr.Button("🔄 Refresh")
216
-
217
- def refresh_files():
218
- return gr.Dropdown(choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")])
219
-
220
- refresh_btn.click(fn=refresh_files, outputs=[file_list])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- # Launch the app
223
  if __name__ == "__main__":
224
- demo.launch(debug=True)
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import random
3
  from datetime import datetime
4
+ import tempfile
5
+ import os
6
+ import edge_tts
7
+ import asyncio
8
+ import warnings
9
+ from gradio_client import Client
10
  import pytz
11
  import re
12
  import json
 
13
 
14
+ warnings.filterwarnings('ignore')
15
 
16
+ # Initialize client outside of interface definition
 
 
 
17
  arxiv_client = None
18
+
19
+ def init_client():
20
  global arxiv_client
21
  if arxiv_client is None:
22
  arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
23
  return arxiv_client
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def generate_story(prompt, model_choice):
26
  """Generate story using specified model"""
27
  try:
28
+ client = init_client()
29
  if client is None:
30
  return "Error: Story generation service is not available."
31
 
 
57
  # Generate story
58
  story = generate_story(prompt, model_choice)
59
  if isinstance(story, str) and story.startswith("Error"):
60
+ return story, None
61
+
 
 
 
62
  # Generate audio
63
  audio_path = asyncio.run(generate_speech(story))
64
 
65
+ return story, audio_path
66
  except Exception as e:
67
+ return f"Error: {str(e)}", None
 
 
 
 
 
 
 
 
 
 
68
 
69
+ # Create the Gradio interface
70
+ with gr.Blocks(title="AI Story Generator") as demo:
71
+ gr.Markdown("""
72
+ # 🎭 AI Story Generator & Narrator
73
+ Generate creative stories and listen to them!
 
 
 
 
 
 
 
 
 
 
 
 
74
  """)
75
 
76
+ with gr.Row():
77
+ with gr.Column():
78
+ prompt_input = gr.Textbox(
79
+ label="Story Concept",
80
+ placeholder="Enter your story idea...",
81
+ lines=3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  )
83
+ model_choice = gr.Dropdown(
84
+ label="Model",
85
+ choices=[
86
+ "mistralai/Mixtral-8x7B-Instruct-v0.1",
87
+ "mistralai/Mistral-7B-Instruct-v0.2"
88
+ ],
89
+ value="mistralai/Mixtral-8x7B-Instruct-v0.1"
90
+ )
91
+ generate_btn = gr.Button("Generate Story")
92
+
93
+ with gr.Row():
94
+ story_output = gr.Textbox(
95
+ label="Generated Story",
96
+ lines=10,
97
+ interactive=False
98
+ )
99
+
100
+ with gr.Row():
101
+ audio_output = gr.Audio(
102
+ label="Story Narration",
103
+ type="filepath"
104
+ )
105
+
106
+ generate_btn.click(
107
+ fn=process_story_and_audio,
108
+ inputs=[prompt_input, model_choice],
109
+ outputs=[story_output, audio_output]
110
+ )
111
 
112
+ # Launch the app using the current pattern
113
  if __name__ == "__main__":
114
+ demo.launch()