import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
from diffusers import StableDiffusionPipeline, DiffusionPipeline
import torch
from PIL import Image
import numpy as np
import os
import tempfile
import moviepy.editor as mpe
import nltk
from pydub import AudioSegment
import warnings
import asyncio
import edge_tts
import random
from datetime import datetime
import pytz
import re
import json
from gradio_client import Client

warnings.filterwarnings("ignore", category=UserWarning)

# Ensure NLTK data is downloaded
nltk.download('punkt')
# Initialize clients
arxiv_client = None

def init_arxiv_client():
    global arxiv_client
    if arxiv_client is None:
        arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
    return arxiv_client
# File I/O Functions
def generate_filename(prompt, timestamp=None):
    """Generate a safe filename from prompt and timestamp"""
    if timestamp is None:
        timestamp = datetime.now(pytz.UTC).strftime("%Y%m%d_%H%M%S")
    # Clean the prompt to create a safe filename
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip()
    return f"story_{timestamp}_{safe_prompt}.txt"
def save_story(story, prompt, filename=None):
    """Save story to file with metadata"""
    if filename is None:
        filename = generate_filename(prompt)
    try:
        with open(filename, 'w', encoding='utf-8') as f:
            metadata = {
                'timestamp': datetime.now().isoformat(),
                'prompt': prompt,
                'type': 'story'
            }
            f.write(json.dumps(metadata) + '\n---\n' + story)
        return filename
    except Exception as e:
        print(f"Error saving story: {e}")
        return None
def load_story(filename):
    """Load story and metadata from file"""
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        parts = content.split('\n---\n')
        if len(parts) == 2:
            metadata = json.loads(parts[0])
            story = parts[1]
            return metadata, story
        return None, content
    except Exception as e:
        print(f"Error loading story: {e}")
        return None, None
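
# Hypothetical convenience helper (not in the original code): the saved-story dropdown
# and its refresh button further below both rebuild the same directory listing, so a
# single function assuming the "story_*.txt" naming from generate_filename() keeps that
# filter in one place if you want to reuse it there.
def list_saved_stories():
    """Return filenames of stories written by save_story(), newest first."""
    files = [f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")]
    return sorted(files, reverse=True)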
# Story Generation Functions
def generate_story(prompt, model_choice):
    """Generate story using specified model"""
    try:
        client = init_arxiv_client()
        if client is None:
            return "Error: Story generation service is not available."
        result = client.predict(
            prompt=prompt,
            llm_model_picked=model_choice,
            stream_outputs=True,
            api_name="/ask_llm"
        )
        return result
    except Exception as e:
        return f"Error generating story: {str(e)}"
async def generate_speech(text, voice="en-US-AriaNeural"):
    """Generate speech from text"""
    try:
        communicate = edge_tts.Communicate(text, voice)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"Error in generate_speech: {str(e)}")
        return None
def process_story_and_audio(prompt, model_choice):
    """Process story and generate audio"""
    try:
        # Generate story
        story = generate_story(prompt, model_choice)
        if isinstance(story, str) and story.startswith("Error"):
            return story, None, None

        # Save story
        filename = save_story(story, prompt)

        # Generate audio
        audio_path = asyncio.run(generate_speech(story))

        return story, audio_path, filename
    except Exception as e:
        return f"Error: {str(e)}", None, None
# Main App Code (your existing code remains here)
# LLM Inference Class and other existing classes remain unchanged
class LLMInferenceNode:
    # Your existing LLMInferenceNode implementation
    pass
# Initialize models (your existing initialization code remains here)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Story generator
story_generator = pipeline(
    'text-generation',
    model='gpt2-large',
    device=0 if device == 'cuda' else -1
)
# Stable Diffusion model
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch_dtype
).to(device)
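
# NOTE: `process_pipeline` is wired to the "Video Generation" tab below, but its real
# implementation lives in the existing app code referenced by the placeholders above.
# The following is only a minimal sketch, assuming the pipeline extends the prompt into
# a short story with the GPT-2 pipeline, renders one Stable Diffusion frame per sentence,
# and stitches the frames into a silent video with moviepy. Replace it with the real
# implementation if you have one.
def process_pipeline(prompt):
    """Sketch: prompt -> short story -> one image per sentence -> video file path."""
    # Extend the prompt into a short story.
    story = story_generator(prompt, max_length=120, num_return_sequences=1)[0]["generated_text"]

    # Render one image per sentence, capped to keep generation time reasonable.
    sentences = nltk.sent_tokenize(story)[:5]
    frame_paths = []
    for i, sentence in enumerate(sentences):
        image = sd_pipe(sentence).images[0]
        frame_path = os.path.join(tempfile.gettempdir(), f"frame_{i}.png")
        image.save(frame_path)
        frame_paths.append(frame_path)

    # Stitch the frames into a short clip (each frame shown for two seconds).
    clip = mpe.ImageSequenceClip(frame_paths, fps=0.5)
    video_path = os.path.join(tempfile.gettempdir(), "generated_video.mp4")
    clip.write_videofile(video_path, fps=24)
    return video_path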
# Create the enhanced Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""# 🎨 AI Creative Suite
    Generate videos, stories, and more with AI!
    """)
    with gr.Tabs():
        # Your existing video generation tab
        with gr.Tab("Video Generation"):
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
                    generate_button = gr.Button("Generate Video")
                with gr.Column():
                    video_output = gr.Video(label="Generated Video")
            generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)
        # New story generation tab
        with gr.Tab("Story Generation"):
            with gr.Row():
                with gr.Column():
                    story_prompt = gr.Textbox(
                        label="Story Concept",
                        placeholder="Enter your story idea...",
                        lines=3
                    )
                    model_choice = gr.Dropdown(
                        label="Model",
                        choices=[
                            "mistralai/Mixtral-8x7B-Instruct-v0.1",
                            "mistralai/Mistral-7B-Instruct-v0.2"
                        ],
                        value="mistralai/Mixtral-8x7B-Instruct-v0.1"
                    )
                    generate_story_btn = gr.Button("Generate Story")

            with gr.Row():
                story_output = gr.Textbox(
                    label="Generated Story",
                    lines=10,
                    interactive=False
                )

            with gr.Row():
                audio_output = gr.Audio(
                    label="Story Narration",
                    type="filepath"
                )
                filename_output = gr.Textbox(
                    label="Saved Filename",
                    interactive=False
                )

            generate_story_btn.click(
                fn=process_story_and_audio,
                inputs=[story_prompt, model_choice],
                outputs=[story_output, audio_output, filename_output]
            )
            # File management section
            with gr.Row():
                file_list = gr.Dropdown(
                    label="Saved Stories",
                    choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")],
                    interactive=True
                )
                refresh_btn = gr.Button("🔄 Refresh")

            def refresh_files():
                return gr.Dropdown(choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")])

            refresh_btn.click(fn=refresh_files, outputs=[file_list])
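
            # Hypothetical wiring (not part of the original code): load_story() above is
            # otherwise unused, so this sketch assumes that selecting a saved file should
            # show its text in the story box.
            def load_selected_story(filename):
                if not filename:
                    return ""
                _metadata, story = load_story(filename)
                return story or ""

            file_list.change(
                fn=load_selected_story,
                inputs=[file_list],
                outputs=[story_output]
            )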
# Launch the app
if __name__ == "__main__":
    demo.launch(debug=True)