awacke1 committed on
Commit 1e91094 · verified · 1 Parent(s): 4775701

Create app.py

Files changed (1):
    app.py +224 -0

app.py ADDED
@@ -0,0 +1,224 @@
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import nltk
from pydub import AudioSegment
import warnings
import asyncio
import edge_tts
import random
from datetime import datetime
import pytz
import re
import json
from gradio_client import Client

warnings.filterwarnings("ignore", category=UserWarning)

# Ensure NLTK data is downloaded
nltk.download('punkt')

# Initialize clients
arxiv_client = None

def init_arxiv_client():
    """Lazily create the Gradio client for the remote story-generation Space."""
    global arxiv_client
    if arxiv_client is None:
        arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
    return arxiv_client

# File I/O Functions
def generate_filename(prompt, timestamp=None):
    """Generate a safe filename from prompt and timestamp"""
    if timestamp is None:
        timestamp = datetime.now(pytz.UTC).strftime("%Y%m%d_%H%M%S")
    # Clean the prompt to create a safe filename
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip()
    return f"story_{timestamp}_{safe_prompt}.txt"

def save_story(story, prompt, filename=None):
    """Save story to file with metadata"""
    if filename is None:
        filename = generate_filename(prompt)

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'prompt': prompt,
                'type': 'story'
            }
            f.write(json.dumps(metadata) + '\n---\n' + story)
        return filename
    except Exception as e:
        print(f"Error saving story: {e}")
        return None
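
# Resulting on-disk layout: one JSON metadata line, a '---' separator, then the story:
#   {"timestamp": "...", "prompt": "...", "type": "story"}
#   ---
#   <story text>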

def load_story(filename):
    """Load story and metadata from file"""
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        parts = content.split('\n---\n')
        if len(parts) == 2:
            metadata = json.loads(parts[0])
            story = parts[1]
            return metadata, story
        return None, content
    except Exception as e:
        print(f"Error loading story: {e}")
        return None, None

# Story Generation Functions
def generate_story(prompt, model_choice):
    """Generate story using specified model"""
    try:
        client = init_arxiv_client()
        if client is None:
            return "Error: Story generation service is not available."

        result = client.predict(
            prompt=prompt,
            llm_model_picked=model_choice,
            stream_outputs=True,
            api_name="/ask_llm"
        )
        return result
    except Exception as e:
        return f"Error generating story: {str(e)}"

async def generate_speech(text, voice="en-US-AriaNeural"):
    """Generate speech from text"""
    try:
        communicate = edge_tts.Communicate(text, voice)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"Error in generate_speech: {str(e)}")
        return None

def process_story_and_audio(prompt, model_choice):
    """Generate a story, save it to disk, and narrate it"""
    try:
        # Generate story
        story = generate_story(prompt, model_choice)
        if isinstance(story, str) and story.startswith("Error"):
            return story, None, None

        # Save story
        filename = save_story(story, prompt)

        # Generate audio; asyncio.run is fine here because Gradio calls
        # sync handlers in worker threads with no running event loop.
        audio_path = asyncio.run(generate_speech(story))

        return story, audio_path, filename
    except Exception as e:
        return f"Error: {str(e)}", None, None

# Main App Code (your existing code remains here)
# LLM Inference Class and other existing classes remain unchanged
class LLMInferenceNode:
    # Your existing LLMInferenceNode implementation
    pass

# Initialize models (your existing initialization code remains here)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Story generator
story_generator = pipeline(
    'text-generation',
    model='gpt2-large',
    device=0 if device == 'cuda' else -1
)

# Stable Diffusion model
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch_dtype
).to(device)
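
# The Video Generation tab below wires its button to `process_pipeline`, which
# this commit never defines, so building the interface would raise a NameError.
# The stub below is an assumption, not the author's original pipeline: it renders
# a single Stable Diffusion frame and wraps it in a short clip so the tab works.
def process_pipeline(prompt):
    """Placeholder: turn a text prompt into a short, single-frame video."""
    image = sd_pipe(prompt).images[0]            # one generated frame (PIL)
    frame = np.array(image)                      # PIL -> numpy for moviepy
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_file:
        video_path = tmp_file.name
    clip = mpe.ImageClip(frame).set_duration(2)  # 2-second still clip
    clip.write_videofile(video_path, fps=24, logger=None)
    return video_path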

# Create the enhanced Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# 🎨 AI Creative Suite
    Generate videos, stories, and more with AI!
    """)

    with gr.Tabs():
        # Your existing video generation tab
        with gr.Tab("Video Generation"):
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
                    generate_button = gr.Button("Generate Video")
                with gr.Column():
                    video_output = gr.Video(label="Generated Video")

            generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)

        # New story generation tab
        with gr.Tab("Story Generation"):
            with gr.Row():
                with gr.Column():
                    story_prompt = gr.Textbox(
                        label="Story Concept",
                        placeholder="Enter your story idea...",
                        lines=3
                    )
                    model_choice = gr.Dropdown(
                        label="Model",
                        choices=[
                            "mistralai/Mixtral-8x7B-Instruct-v0.1",
                            "mistralai/Mistral-7B-Instruct-v0.2"
                        ],
                        value="mistralai/Mixtral-8x7B-Instruct-v0.1"
                    )
                    generate_story_btn = gr.Button("Generate Story")

            with gr.Row():
                story_output = gr.Textbox(
                    label="Generated Story",
                    lines=10,
                    interactive=False
                )

            with gr.Row():
                audio_output = gr.Audio(
                    label="Story Narration",
                    type="filepath"
                )
                filename_output = gr.Textbox(
                    label="Saved Filename",
                    interactive=False
                )

            generate_story_btn.click(
                fn=process_story_and_audio,
                inputs=[story_prompt, model_choice],
                outputs=[story_output, audio_output, filename_output]
            )

            # File management section
            with gr.Row():
                file_list = gr.Dropdown(
                    label="Saved Stories",
                    choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")],
                    interactive=True
                )
                refresh_btn = gr.Button("🔄 Refresh")

            def refresh_files():
                return gr.Dropdown(choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")])

            refresh_btn.click(fn=refresh_files, outputs=[file_list])

# Launch the app
if __name__ == "__main__":
    demo.launch(debug=True)