mayf committed on
Commit 613c57d · verified · 1 Parent(s): 618f9ae

Update app.py

Files changed (1)
  1. app.py +46 -126
app.py CHANGED
@@ -1,144 +1,64 @@
- # story_generator.py
- import re
- import time
- import tempfile
  import streamlit as st
  from PIL import Image
  from gtts import gTTS
- from transformers import pipeline

- # --- Initialize Streamlit Config ---
- st.set_page_config(
-     page_title="Magic Story Generator",
-     layout="centered",
-     page_icon="📖"
- )

- # --- Model Loading (Cached) ---
- @st.cache_resource(show_spinner=False)
  def load_models():
-     # Image captioning model
-     captioner = pipeline(
-         "image-to-text",
-         model="Salesforce/blip-image-captioning-base",
-         device=-1  # Force CPU usage
-     )
-
-     # Story generation model with updated parameters
-     storyteller = pipeline(
-         "text-generation",
-         model="Qwen/Qwen3-1.7B",
          device_map="auto",
-         trust_remote_code=True,
-         torch_dtype="auto",
-         temperature=0.7,
-         top_p=0.9,
-         repetition_penalty=1.1,
-         pad_token_id=151645,
-         max_new_tokens=300
      )
-
-     return captioner, storyteller

- # --- Text Processing Utilities ---
- def clean_generated_text(raw_text):
-     # Split at first assistant marker
-     clean_text = raw_text.split("<|im_start|>assistant\n", 1)[-1]
-
-     # Remove any subsequent chat turns
-     clean_text = clean_text.split("<|im_start|>")[0]
-
-     # Remove special tokens and whitespace
-     clean_text = clean_text.replace("<|im_end|>", "").strip()
-
-     # Regex cleanup for remaining markers
-     clean_text = re.sub(
-         r'^(assistant[\s\-\:>]*)+',
-         '',
-         clean_text,
-         flags=re.IGNORECASE
-     ).strip()
-
-     # Format punctuation and capitalization
-     sentences = []
-     for sent in re.split(r'(?<=[.!?]) +', clean_text):
-         sent = sent.strip()
-         if not sent:
-             continue
-         if sent[-1] not in {'.', '!', '?'}:
-             sent += '.'
-         sentences.append(sent[0].upper() + sent[1:])
-
-     return ' '.join(sentences)
-
- # --- Main Application UI ---
- st.title("📖✨ Magic Story Generator")

- uploaded_image = st.file_uploader(
-     "Upload a children's book style image:",
-     type=["jpg", "jpeg", "png"]
- )

  if uploaded_image:
-     # Display uploaded image with modern parameter
-     image = Image.open(uploaded_image).convert("RGB")
-     st.image(image, use_container_width=True)  # Updated parameter

-     # Load models only when needed
-     try:
-         caption_pipe, story_pipe = load_models()
-     except Exception as e:
-         st.error(f"❌ Model loading failed: {str(e)}")
-         st.stop()

-     # Generate image caption
-     with st.spinner("🔍 Analyzing image..."):
-         try:
-             caption_result = caption_pipe(image)
-             image_caption = caption_result[0].get("generated_text", "").strip()
-
-             if not image_caption:
-                 raise ValueError("Empty caption generated")
-
-             st.success(f"**Image Understanding:** {image_caption}")
-         except Exception as e:
-             st.error("❌ Image analysis failed. Please try another image.")
-             st.stop()
-
-     # Create story prompt
-     story_prompt = (
-         f"<|im_start|>system\n"
-         f"You are a children's book author. Create a 150-word story based on: {image_caption}\n"
-     )

-     # Generate story text
-     with st.spinner("📝 Crafting magical story..."):
-         try:
-             story_result = story_pipe(
-                 story_prompt,
-                 num_return_sequences=1
-             )
-             raw_story = story_result[0]['generated_text']
-         except Exception as e:
-             st.error("❌ Story generation failed. Please try again.")
-             st.stop()

-     # Process and display story
-     final_story = clean_generated_text(raw_story)

-     st.subheader("✨ Your Story")
-     st.write(final_story)

-     # Generate audio version
-     with st.spinner("🔊 Creating audio version..."):
-         try:
-             tts = gTTS(text=final_story, lang='en', slow=False)
-             with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
-                 tts.save(fp.name)
-                 st.audio(fp.read(), format="audio/mp3")
-         except Exception as e:
-             st.warning("⚠️ Audio conversion failed. Text version still available.")
-
-     # Footer
-     st.markdown("---")
-     st.caption("Made with ♥ by The Story Wizard • [Report Issues](https://example.com)")

  import streamlit as st
+ import torch
  from PIL import Image
  from gtts import gTTS
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

+ # Streamlit config must be first
+ st.set_page_config(page_title="Magic Story Generator", layout="centered", page_icon="📖")

+ # Model loading cached for performance
+ @st.cache_resource
  def load_models():
+     caption_model = pipeline("image-to-text", "Salesforce/blip-image-captioning-base")
+     story_model = AutoModelForCausalLM.from_pretrained(
+         "Qwen/Qwen3-1.7B",
          device_map="auto",
+         torch_dtype=torch.float16,
+         trust_remote_code=True
      )
+     story_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
+     return caption_model, story_model, story_tokenizer

+ # Initialize models
+ caption_pipe, story_model, story_tokenizer = load_models()

+ # Main app interface
+ st.title("📖 Instant Story Generator")
+ uploaded_image = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])

  if uploaded_image:
+     img = Image.open(uploaded_image).convert("RGB")
+     st.image(img, caption="Your Image", use_column_width=True)

+     # Generate caption
+     caption = caption_pipe(img)[0]['generated_text']

+     # Generate story
+     messages = [{
+         "role": "system",
+         "content": f"Create a 50 to 100 words children's story based on: {caption}."
+     }]

+     inputs = story_tokenizer.apply_chat_template(
+         messages,
+         return_tensors="pt"
+     ).to(story_model.device)

+     outputs = story_model.generate(
+         inputs,
+         max_new_tokens=300,
+         temperature=0.7,
+         top_p=0.9
+     )

+     # Display results
+     story = story_tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
+     st.subheader("Generated Story")
+     st.write(story)

+     # Audio conversion
+     audio = gTTS(text=story, lang='en')
+     with tempfile.NamedTemporaryFile(delete=False) as fp:
+         audio.save(fp.name)
+         st.audio(fp.name, format='audio/mp3')
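
Note on the new version: as committed, app.py still calls tempfile.NamedTemporaryFile but no longer imports tempfile, passes temperature/top_p to generate() without do_sample=True (so sampling settings are ignored and transformers falls back to greedy decoding with a warning), and uses st.image(..., use_column_width=True), a parameter Streamlit has deprecated in favor of use_container_width. The sketch below shows one way the generation and audio steps could be wired so they run; the helper names generate_story and speak, the user-role message, and the added flags are assumptions of this note, not part of the commit.

# Hedged sketch, not part of the commit: minimal helpers for the new generation path.
import tempfile  # the committed app.py still uses tempfile but dropped this import

import streamlit as st
from gtts import gTTS


def generate_story(caption, model, tokenizer):
    # The commit puts the instruction in a system message; a plain user turn also works.
    messages = [{
        "role": "user",
        "content": f"Write a 50-100 word children's story based on: {caption}",
    }]
    # add_generation_prompt=True makes the template end with an assistant turn;
    # temperature/top_p only take effect when do_sample=True.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    outputs = model.generate(
        inputs,
        max_new_tokens=300,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Decode only the newly generated tokens, as the committed code does.
    return tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)


def speak(story):
    # gTTS writes an mp3 to disk; Streamlit plays it back from the saved path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as fp:
        gTTS(text=story, lang="en").save(fp.name)
        st.audio(fp.name, format="audio/mp3")

The Qwen3 chat template is also reported to accept an enable_thinking=False argument via apply_chat_template, which suppresses the model's thinking block before the story text.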