1
File size: 2,015 Bytes
4d1f328
613c57d
c83a777
 
613c57d
c83a777
613c57d
 
c83a777
613c57d
 
982555a
613c57d
 
 
c4110d1
613c57d
 
6adb177
613c57d
 
c83a777
613c57d
 
982555a
613c57d
 
 
fd1d947
982555a
613c57d
 
982555a
613c57d
 
4d1f328
613c57d
 
 
 
 
e537b6d
613c57d
 
 
 
e537b6d
613c57d
 
 
 
 
 
c4110d1
613c57d
 
 
 
4d1f328
613c57d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import io
import tempfile

import streamlit as st
import torch
from gtts import gTTS
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Streamlit page configuration — must be the first Streamlit call in the script
st.set_page_config(page_title="Magic Story Generator", layout="centered", page_icon="📖")

# Model loading cached for performance (st.cache_resource runs this once per session)
@st.cache_resource
def load_models():
    """Load and cache the captioning pipeline, the story LLM, and its tokenizer.

    Returns:
        tuple: (image-to-text pipeline, causal LM, tokenizer) — same order the
        caller unpacks.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
    llm = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-1.7B",
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    captioner = pipeline("image-to-text", "Salesforce/blip-image-captioning-base")
    return captioner, llm, tokenizer

# Initialize models (cached across Streamlit reruns by st.cache_resource)
caption_pipe, story_model, story_tokenizer = load_models()

# Main app interface
st.title("📖 Instant Story Generator")
uploaded_image = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])

if uploaded_image:
    img = Image.open(uploaded_image).convert("RGB")
    # use_container_width replaces the deprecated use_column_width parameter
    st.image(img, caption="Your Image", use_container_width=True)

    # Describe the image with the BLIP captioning pipeline
    caption = caption_pipe(img)[0]['generated_text']

    # Build the story prompt from the caption.
    # NOTE(review): the instruction is sent as a lone "system" message with no
    # user turn — some chat templates expect a "user" role; confirm against
    # the Qwen3 chat template.
    messages = [{
        "role": "system",
        "content": f"Create a 50 to 100 words children's story based on: {caption}."
    }]

    inputs = story_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # append the assistant header so the model starts a reply
        return_tensors="pt"
    ).to(story_model.device)

    # No gradients needed at inference time; also slightly faster/leaner
    with torch.inference_mode():
        outputs = story_model.generate(
            inputs,
            max_new_tokens=300,
            do_sample=True,   # required: temperature/top_p are ignored under greedy decoding
            temperature=0.7,
            top_p=0.9
        )

    # Decode only the newly generated continuation (skip the prompt tokens)
    story = story_tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    st.subheader("Generated Story")
    st.write(story)

    # Audio conversion — stream MP3 bytes in memory; avoids the original
    # NamedTemporaryFile(delete=False) which leaked a file on disk (and
    # `tempfile` was never imported, a guaranteed NameError).
    if story.strip():  # gTTS raises on empty text
        audio_buf = io.BytesIO()
        gTTS(text=story, lang='en').write_to_fp(audio_buf)
        audio_buf.seek(0)
        st.audio(audio_buf, format='audio/mp3')