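# Interactive travel story generator: Gradio UI + OpenRouter vision model + gTTS audio.
# Assumed install (package names inferred from the imports below): pip install gradio openai pillow gTTS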
import gradio as gr
from openai import OpenAI
from PIL import Image
import base64
import io
from gtts import gTTS


def pil_to_base64(image, max_size=1024):
    """
    Convert a PIL image to a base64 string, resizing if necessary.
    """
    if max(image.size) > max_size:
        image.thumbnail((max_size, max_size))
    # JPEG has no alpha channel, so convert RGBA/P/LA images before saving.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode()


def generate_initial_story(image, genre, api_key):
    """
    Generate the initial travel story based on the image and genre.
    """
    if not image or not genre or not api_key:
        return "Please provide all inputs.", []

    image_base64 = pil_to_base64(image)
    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    prompt = f"Generate a {genre} story based on this travel photo."
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
                },
            ],
        }
    ]
    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-3.2-11b-vision-instruct:free",
            messages=messages,
        )
        story = completion.choices[0].message.content
        messages.append({"role": "assistant", "content": story})
        return story, messages
    except Exception as e:
        return f"Error: {str(e)}", []


def generate_continuation(continuation_prompt, messages, api_key):
    """
    Generate a continuation of the story based on the provided prompt.
    """
    if not continuation_prompt or not messages or not api_key:
        return "Please generate a story first, then enter a continuation prompt and your API key.", messages

    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
    new_user_message = {"role": "user", "content": continuation_prompt}
    updated_messages = messages + [new_user_message]
    try:
        completion = client.chat.completions.create(
            model="meta-llama/llama-3.2-11b-vision-instruct:free",
            messages=updated_messages,
        )
        continuation = completion.choices[0].message.content
        updated_messages.append({"role": "assistant", "content": continuation})
        # Show the whole story so far by joining every assistant turn.
        full_story = "\n\n".join(
            msg["content"] for msg in updated_messages if msg["role"] == "assistant"
        )
        return full_story, updated_messages
    except Exception as e:
        return f"Error: {str(e)}", messages


def generate_audio(story):
    """
    Generate an audio file from the story text using gTTS.
    """
    if not story:
        return None
    tts = gTTS(text=story, lang='en')
    audio_file = "story.mp3"
    tts.save(audio_file)
    return audio_file


with gr.Blocks() as demo:
    gr.Markdown("# Interactive Travel Story Generator")
    gr.Markdown("Upload a travel photo, select a genre, and provide your OpenRouter API key to generate a personalized travel story.")
    gr.Markdown("After generating a story, enter a prompt below to continue it, or click 'Generate Audio' to hear it!")
    gr.Markdown("Note: You need an OpenRouter API key from [OpenRouter](https://openrouter.ai/).")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Travel Photo")
        genre_input = gr.Textbox(label="Story Genre (e.g., adventure, romance, mystery)")
        api_key_input = gr.Textbox(label="OpenRouter API Key", type="password")

    generate_button = gr.Button("Generate Story")
    story_output = gr.Textbox(label="Generated Story", lines=10)

    with gr.Row():
        tts_button = gr.Button("Generate Audio")
        audio_output = gr.Audio(label="Story Audio")

    continuation_prompt = gr.Textbox(label="Continuation Prompt (e.g., 'Now, the hero finds a mysterious map.')")
    continue_button = gr.Button("Continue Story")

    message_state = gr.State([])

    def on_generate_story(image, genre, api_key):
        story, messages = generate_initial_story(image, genre, api_key)
        return story, messages

    generate_button.click(
        fn=on_generate_story,
        inputs=[image_input, genre_input, api_key_input],
        outputs=[story_output, message_state],
    )

    def on_generate_continuation(continuation_prompt, message_state, api_key):
        full_story, updated_messages = generate_continuation(continuation_prompt, message_state, api_key)
        return full_story, updated_messages

    continue_button.click(
        fn=on_generate_continuation,
        inputs=[continuation_prompt, message_state, api_key_input],
        outputs=[story_output, message_state],
    )

    def on_generate_audio(story):
        audio_file = generate_audio(story)
        return audio_file

    tts_button.click(
        fn=on_generate_audio,
        inputs=story_output,
        outputs=audio_output,
    )

demo.launch()
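# Note: demo.launch(share=True) would additionally create a temporary public Gradio link, if desired.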