mayf commited on
Commit
8151df4
·
verified ·
1 Parent(s): 613c57d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -45
app.py CHANGED
@@ -1,64 +1,131 @@
 
1
  import streamlit as st
2
- import torch
 
 
 
 
 
 
 
 
 
3
  from PIL import Image
4
  from gtts import gTTS
5
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
 
7
- # Streamlit config must be first
8
- st.set_page_config(page_title="Magic Story Generator", layout="centered", page_icon="📖")
9
 
10
- # Model loading cached for performance
11
- @st.cache_resource
12
  def load_models():
13
- caption_model = pipeline("image-to-text", "Salesforce/blip-image-captioning-base")
14
- story_model = AutoModelForCausalLM.from_pretrained(
15
- "Qwen/Qwen3-1.7B",
 
 
 
 
 
 
 
 
16
  device_map="auto",
17
- torch_dtype=torch.float16,
18
- trust_remote_code=True
 
 
 
 
 
19
  )
20
- story_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
21
- return caption_model, story_model, story_tokenizer
22
 
23
- # Initialize models
24
- caption_pipe, story_model, story_tokenizer = load_models()
25
 
26
- # Main app interface
27
- st.title("📖 Instant Story Generator")
28
- uploaded_image = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
 
 
29
 
30
  if uploaded_image:
31
- img = Image.open(uploaded_image).convert("RGB")
32
- st.image(img, caption="Your Image", use_column_width=True)
33
-
 
34
  # Generate caption
35
- caption = caption_pipe(img)[0]['generated_text']
 
 
36
 
37
- # Generate story
38
- messages = [{
39
- "role": "system",
40
- "content": f"Create a 50 to 100 words children's story based on: {caption}."
41
- }]
42
 
43
- inputs = story_tokenizer.apply_chat_template(
44
- messages,
45
- return_tensors="pt"
46
- ).to(story_model.device)
47
-
48
- outputs = story_model.generate(
49
- inputs,
50
- max_new_tokens=300,
51
- temperature=0.7,
52
- top_p=0.9
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # Display results
56
- story = story_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
57
- st.subheader("Generated Story")
58
- st.write(story)
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Audio conversion
61
- audio = gTTS(text=story, lang='en')
62
- with tempfile.NamedTemporaryFile(delete=False) as fp:
63
- audio.save(fp.name)
64
- st.audio(fp.name, format='audio/mp3')
 
 
 
 
 
 
 
 
 
1
+ # Must be FIRST import and FIRST Streamlit command
2
  import streamlit as st
3
+ st.set_page_config(
4
+ page_title="Magic Story Generator",
5
+ layout="centered",
6
+ page_icon="📖"
7
+ )
8
+
9
+ # Other imports AFTER Streamlit config
10
+ import re
11
+ import time
12
+ import tempfile
13
  from PIL import Image
14
  from gtts import gTTS
15
+ from transformers import pipeline
16
 
17
+ # --- Constants & Setup ---
18
+ st.title("📖✨ Turn Images into Children's Stories")
19
 
20
+ # --- Model Loading (Cached) ---
21
+ @st.cache_resource(show_spinner=False)
22
  def load_models():
23
+ # Image captioning model
24
+ captioner = pipeline(
25
+ "image-to-text",
26
+ model="Salesforce/blip-image-captioning-base",
27
+ device=-1 # Use -1 for CPU, 0 for GPU
28
+ )
29
+
30
+ # Story generation model (Qwen3-1.7B)
31
+ storyteller = pipeline(
32
+ "text-generation",
33
+ model="Qwen/Qwen3-1.7B",
34
  device_map="auto",
35
+ trust_remote_code=True,
36
+ torch_dtype="auto",
37
+ max_new_tokens=250,
38
+ temperature=0.7,
39
+ top_p=0.85,
40
+ repetition_penalty=1.15,
41
+ eos_token_id=151645
42
  )
43
+
44
+ return captioner, storyteller
45
 
46
+ caption_pipe, story_pipe = load_models()
 
47
 
48
+ # --- Main Application Flow ---
49
+ uploaded_image = st.file_uploader(
50
+ "Upload a children's book style image:",
51
+ type=["jpg", "jpeg", "png"]
52
+ )
53
 
54
  if uploaded_image:
55
+ # Process image
56
+ image = Image.open(uploaded_image).convert("RGB")
57
+ st.image(image, use_container_width=True)
58
+
59
  # Generate caption
60
+ with st.spinner("🔍 Analyzing image..."):
61
+ caption_result = caption_pipe(image)
62
+ image_caption = caption_result[0].get("generated_text", "").strip()
63
 
64
+ if not image_caption:
65
+ st.error("❌ Couldn't understand this image. Please try another!")
66
+ st.stop()
 
 
67
 
68
+ st.success(f"**Image Understanding:** {image_caption}")
69
+
70
+ # Create story prompt
71
+ story_prompt = (
72
+ f"<|im_start|>system\n"
73
+ f"You are a children's book author. Create a 100-150 word story based on: {image_caption}\n"
 
 
 
 
74
  )
75
+
76
+ # Generate story
77
+ with st.spinner("📝 Crafting magical story..."):
78
+ start_time = time.time()
79
+ story_result = story_pipe(
80
+ story_prompt,
81
+ do_sample=True,
82
+ num_return_sequences=1,
83
+ pad_token_id=151645
84
+ )
85
+ generation_time = time.time() - start_time
86
+
87
+ # Process output
88
+ raw_story = story_result[0]['generated_text']
89
 
90
+ # Clean up story text
91
+ clean_story = raw_story.split("<|im_start|>assistant\n")[-1]
92
+ clean_story = clean_story.split("<|im_start|>")[0] # Remove any new turns
93
+ clean_story = clean_story.replace("<|im_end|>", "").strip()
94
 
95
+ # Remove assistant mentions using regex
96
+ clean_story = re.sub(
97
+ r'^(assistant[:>]?\s*)+',
98
+ '',
99
+ clean_story,
100
+ flags=re.IGNORECASE
101
+ ).strip()
102
+
103
+ # Format story punctuation
104
+ final_story = []
105
+ for sentence in clean_story.split(". "):
106
+ sentence = sentence.strip()
107
+ if not sentence:
108
+ continue
109
+ if not sentence.endswith('.'):
110
+ sentence += '.'
111
+ final_story.append(sentence[0].upper() + sentence[1:])
112
+
113
+ final_story = " ".join(final_story).replace("..", ".")[:800]
114
+
115
+ # Display story
116
+ st.subheader("✨ Your Magical Story")
117
+ st.write(final_story)
118
+
119
  # Audio conversion
120
+ with st.spinner("🔊 Creating audio version..."):
121
+ try:
122
+ audio = gTTS(text=final_story, lang="en", slow=False)
123
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
124
+ audio.save(tmp_file.name)
125
+ st.audio(tmp_file.name, format="audio/mp3")
126
+ except Exception as e:
127
+ st.error(f"❌ Audio conversion failed: {str(e)}")
128
+
129
+ # Footer
130
+ st.markdown("---")
131
+ st.markdown("📚 Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")