File size: 3,246 Bytes
c1ca117
fef6f0f
4b3ee30
c1ca117
 
 
 
 
ba9176f
c1ca117
 
de391e9
dfd6986
c1ca117
dfd6986
f02143a
 
 
 
 
 
 
 
de391e9
 
 
dfd6986
f02143a
 
de391e9
 
f02143a
de391e9
 
 
 
 
 
dfd6986
de391e9
dfd6986
c1ca117
de391e9
efd3d3c
c1ca117
 
 
 
 
dfd6986
c1ca117
dfd6986
 
de391e9
 
dfd6986
c1ca117
de391e9
f02143a
c1ca117
de391e9
 
c1ca117
 
 
dfd6986
de391e9
c1ca117
f02143a
 
 
 
 
 
 
 
 
 
 
 
 
c1ca117
f02143a
 
 
 
 
 
 
de391e9
f02143a
 
 
 
 
 
 
 
 
 
 
 
 
c1ca117
 
f02143a
 
 
 
 
de391e9
f02143a
de391e9
f02143a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
import random
from datetime import datetime
import tempfile
import os
import edge_tts
import asyncio
import warnings
from gradio_client import Client
import pytz
import re
import json

# Silence every Python warning category for the whole process
warnings.filterwarnings('ignore')

# Module-level cache for the remote client; filled in on first use
arxiv_client = None

def init_client():
    """Return the shared remote client, constructing it on the first call.

    The instance is stored in the module-level ``arxiv_client`` so later
    calls reuse it instead of building a new one.
    """
    global arxiv_client
    if arxiv_client is not None:
        return arxiv_client
    arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
    return arxiv_client

def generate_story(prompt, model_choice):
    """Request a story for ``prompt`` from the ``/ask_llm`` endpoint.

    Returns whatever the endpoint call produces. When no client is
    available, or when anything raises along the way, a message string
    beginning with "Error" is returned instead of propagating.
    """
    try:
        llm = init_client()
        if llm is not None:
            return llm.predict(
                prompt=prompt,
                llm_model_picked=model_choice,
                stream_outputs=True,
                api_name="/ask_llm",
            )
        return "Error: Story generation service is not available."
    except Exception as exc:
        return f"Error generating story: {exc!s}"

async def generate_speech(text, voice="en-US-AriaNeural"):
    """Synthesize ``text`` into an MP3 file using edge-tts.

    Args:
        text: The text to narrate.
        voice: Voice name passed to ``edge_tts.Communicate``.

    Returns:
        Path to the generated ``.mp3`` temp file, or ``None`` when there is
        nothing to narrate or synthesis fails. The caller owns the returned
        file; it is not deleted automatically.
    """
    # Nothing to narrate: skip the request and avoid creating an empty file
    if not isinstance(text, str) or not text.strip():
        print("Error in text2speech: no text to narrate")
        return None

    tmp_path = None
    try:
        communicate = edge_tts.Communicate(text, voice)
        # Reserve a unique path, then close our handle so the save call is
        # the only writer on that file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        # delete=False means nobody else will clean up a failed attempt
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None

def process_story_and_audio(prompt, model_choice):
    """Produce a story and its narration for the UI.

    Returns a ``(story, audio_path)`` pair. When story generation reports
    an error string, or anything raises, the second element is ``None``
    and the first carries the error text.
    """
    try:
        story = generate_story(prompt, model_choice)

        # generate_story signals failure with a string starting with "Error"
        failed = isinstance(story, str) and story.startswith("Error")
        if not failed:
            narration = asyncio.run(generate_speech(story))
            return story, narration

        return story, None
    except Exception as exc:
        return f"Error: {exc!s}", None

# Build the Gradio interface: inputs on top, then the story text, then audio
with gr.Blocks(title="AI Story Generator") as demo:
    # Page heading (the emoji was previously stored as mis-decoded bytes)
    gr.Markdown("""
    # 🎭 AI Story Generator & Narrator
    Generate creative stories and listen to them!
    """)
    
    # Input controls: story idea, model selection, and the trigger button
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Story Concept",
                placeholder="Enter your story idea...",
                lines=3
            )
            model_choice = gr.Dropdown(
                label="Model",
                choices=[
                    "mistralai/Mixtral-8x7B-Instruct-v0.1",
                    "mistralai/Mistral-7B-Instruct-v0.2"
                ],
                value="mistralai/Mixtral-8x7B-Instruct-v0.1"
            )
            generate_btn = gr.Button("Generate Story")
    
    # Read-only box that shows the generated story (or an error message)
    with gr.Row():
        story_output = gr.Textbox(
            label="Generated Story",
            lines=10,
            interactive=False
        )
    
    # Audio player fed with the path returned by the handler
    with gr.Row():
        audio_output = gr.Audio(
            label="Story Narration",
            type="filepath"
        )
    
    # The handler returns (story, audio_path), matching the two outputs
    generate_btn.click(
        fn=process_story_and_audio,
        inputs=[prompt_input, model_choice],
        outputs=[story_output, audio_output]
    )

# Launch the demo only when this file is executed directly, not on import
if __name__ == "__main__":
    demo.launch()