# Hugging Face Space app (scraped page header read: "Spaces — Runtime error")
import gradio as gr | |
from unsloth import FastLanguageModel | |
import torch | |
# Load the pre-trained language model and tokenizer.
# 4-bit quantized Llama-3-8B checkpoint served through unsloth's FastLanguageModel.
model_name = "suhaif/unsloth-llama-3-8b-4bit"
max_seq_length = 2048  # maximum context length for tokenization / generation
dtype = None  # NOTE(review): None presumably lets unsloth auto-select fp16/bf16 — confirm against unsloth docs
load_in_4bit = True  # 4-bit quantization so the 8B model fits in limited VRAM
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit
)
# Default instruction for generating the story
default_instruction = "You are a creative writer. Based on the given input, generate a well-structured story with an engaging plot, well-developed characters, and immersive details. Ensure the story has a clear beginning, middle, and end. Include dialogue and descriptions to bring the story to life. You can add twist to the story also"

# Function to format the prompt
def format_prompt(input_text, instruction=default_instruction):
    """Assemble the full model prompt: instruction, the user's input,
    and a trailing ``Response:`` stub for the model to complete."""
    sections = (instruction, f"Input:\n{input_text}", "Response:\n")
    return "\n\n".join(sections)
# Function to generate story from the model
def generate_story(user_input):
    """Generate a story continuation for *user_input*.

    Formats the prompt with ``format_prompt`` and returns the decoded
    model output (prompt + continuation, special tokens stripped).
    """
    # Format the input
    prompt = format_prompt(user_input)
    # Fix: move tensors to the model's own device instead of hard-coded
    # "cuda" — the original raised a runtime error on CPU-only hardware
    # (no CUDA available on the free Spaces tier).
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # Generate output from the model
    outputs = model.generate(**inputs, max_new_tokens=500, use_cache=True)
    # Decode and return the result
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Feedback mechanism (collects and stores feedback)
feedback_data = []

def submit_feedback(rating, feedback_text, story):
    """Record one feedback entry in the in-memory store and return a
    thank-you message for the UI."""
    entry = {"rating": rating, "feedback_text": feedback_text, "story": story}
    feedback_data.append(entry)
    return "Thank you for your feedback!"
# Community engagement feature - to upload and share stories
shared_stories = []

def share_story(title, story_text):
    """Append a user-submitted story to the in-memory store and return
    a confirmation message."""
    record = {"title": title, "story_text": story_text}
    shared_stories.append(record)
    return f"Story '{title}' has been shared successfully!"

def display_stories():
    """Return every shared story as a (title, text) pair, in share order."""
    pairs = []
    for record in shared_stories:
        pairs.append((record['title'], record['story_text']))
    return pairs
# Gradio interface
def storytelling_interface():
    """Build and launch the Gradio Blocks UI: story generation, feedback
    collection, story sharing, and browsing of shared stories."""
    # User inputs
    with gr.Blocks() as demo:
        gr.Markdown("# Interactive Storytelling Assistant")
        with gr.Row():
            with gr.Column():
                # Prompt entry; generated story appears in a read-only box.
                user_input = gr.Textbox(label="Enter your story prompt", placeholder="A young adventurer embarks on a journey to find a lost treasure...", lines=4)
                generate_button = gr.Button("Generate Story")
                story_output = gr.Textbox(label="Generated Story", placeholder="Generated story will appear here...", lines=10, interactive=False)
                generate_button.click(fn=generate_story, inputs=user_input, outputs=story_output)
            with gr.Column():
                gr.Markdown("## Provide Feedback")
                rating = gr.Slider(1, 5, step=1, label="Rate the story")
                feedback_text = gr.Textbox(label="Feedback", placeholder="Provide any suggestions or comments...", lines=3)
                submit_feedback_button = gr.Button("Submit Feedback")
                # outputs=None: the thank-you string returned by
                # submit_feedback is discarded, never shown to the user.
                submit_feedback_button.click(fn=submit_feedback, inputs=[rating, feedback_text, story_output], outputs=None)
        with gr.Row():
            gr.Markdown("## Share your Story")
            title = gr.Textbox(label="Story Title", placeholder="Enter the title of your story")
            story_text = gr.Textbox(label="Your Story", placeholder="Enter your full story here...", lines=8)
            share_button = gr.Button("Share Story")
            share_button.click(fn=share_story, inputs=[title, story_text], outputs=None)
        with gr.Row():
            gr.Markdown("## Browse Shared Stories")
            # NOTE(review): the display_stories callable is passed as the
            # Dropdown's choices; nothing explicitly refreshes this dropdown
            # after share_button fires — verify against the Gradio version
            # used that choices are re-evaluated, otherwise new stories
            # never appear here.
            stories_list = gr.Dropdown(display_stories, label="Select a story to read")
            story_display = gr.Textbox(label="Story Content", lines=10, interactive=False)
            # NOTE(review): next(...) has no default, so a selected title
            # absent from shared_stories raises StopIteration.
            stories_list.change(fn=lambda title: next(story['story_text'] for story in shared_stories if story['title'] == title), inputs=stories_list, outputs=story_display)
    demo.launch()
# Start the storytelling interface
storytelling_interface()