|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
# Page header. The emoji are written as escapes (wizard: U+1F9D9 + ZWJ + male
# sign + VS16; books: U+1F4DA) so they survive any re-encoding of this file.
st.title("\U0001F9D9\u200d\u2642\ufe0f Magic Story Buddy \U0001F4DA")
st.markdown("Let's create a magical story just for you!")


@st.cache_resource
def load_model():
    """Load the storyteller model and its tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading the checkpoint every time a widget changes.

    Returns:
        tuple: ``(model, tokenizer)``. The model is moved to the GPU when CUDA
        is available; otherwise it stays on the CPU.
    """
    repo_id = "ajibawa-2023/Young-Children-Storyteller-Mistral-7B"
    # from_pretrained loads weights onto the CPU; without an explicit move a
    # GPU host would still run generation on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): float16 on a CPU-only host can be very slow (and some ops are
    # unsupported on older torch versions) -- consider float32/bfloat16 there.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return model, tokenizer


model, tokenizer = load_model()


# --- Story options chosen by the child --------------------------------------
child_name = st.text_input("What's your name, young storyteller?")
story_theme = st.selectbox(
    "What would you like your story to be about?",
    [
        "Space Adventure",
        "Magical Forest",
        "Underwater World",
        "Dinosaur Discovery",
    ],
)
# Desired story length in words: 50 to 200, starting at 100.
story_length = st.slider(
    "How long should the story be?",
    min_value=50,
    max_value=200,
    value=100,
)
include_moral = st.checkbox("Include a moral lesson?")


def generate_story(prompt, max_length=500):
    """Continue *prompt* with the storyteller model and return the decoded text.

    Uses the module-level ``model`` and ``tokenizer`` returned by ``load_model``.

    Args:
        prompt: Text for the model to continue.
        max_length: Maximum number of tokens to generate *after* the prompt.
            It is forwarded as ``max_new_tokens``. ``generate``'s own
            ``max_length`` counts the prompt tokens as well, so a small budget
            previously left little or no room for the story itself.

    Returns:
        The prompt followed by the sampled continuation, decoded to text with
        special tokens removed.
    """
    # Keep the input tensors on the same device as the model (CPU or GPU).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_length,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# --- Build the prompt, generate, and show the story --------------------------
if st.button("Create My Story!"):
    # A name made only of spaces should not count as "provided".
    hero_name = child_name.strip()
    if hero_name and story_theme:
        article = "an" if story_theme[0].lower() in "aeiou" else "a"
        # No trailing space: with sentencepiece-style tokenizers a dangling
        # space becomes its own token and tends to degrade the continuation.
        opening = (
            f"Once upon a time, in {article} {story_theme.lower()}, "
            f"there was a brave child named {hero_name}."
        )

        details = [
            f"- Main character: {hero_name}",
            f"- Theme: {story_theme}",
            f"- Length: About {story_length} words",
            "- Audience: Children aged 5-10",
            "- Tone: Friendly, educational, and imaginative",
        ]
        if include_moral:
            # Ask for the moral as an instruction. Appending "This story teaches
            # us that " right after the opening sentence made the model state the
            # moral before telling any story at all.
            details.append(
                '- Moral: End with a sentence that begins "This story teaches us that"'
            )

        prompt = (
            "Create a short children's story with the following details:\n"
            + "\n".join(details)
            + "\n\nStory:\n"
            + opening
        )

        # The slider is in words, but the generation budget is in tokens. A rough
        # rule of thumb is ~1.3 tokens per English word, so allow some headroom.
        token_budget = int(story_length * 1.5)

        with st.spinner("Writing your magical story..."):
            generated = generate_story(prompt, max_length=token_budget)

        # The decoded output echoes the instruction header; show only the story.
        start = generated.find("Once upon a time")
        story = generated[start:] if start != -1 else generated

        st.markdown("## Your Magical Story")
        st.write(story.strip())
        st.balloons()
    else:
        st.warning("Please tell me your name and choose a story theme.")


st.markdown("---")
# Glowing-star emoji (U+1F31F) written as escapes so they survive any
# re-encoding of this file.
st.markdown("\U0001F31F Remember, you're the star of every story! \U0001F31F")