mayf committed
Commit c83a777 · verified · 1 Parent(s): 154acfe

Update app.py

Files changed (1):
  1. app.py +98 -108
app.py CHANGED
@@ -1,125 +1,115 @@
-# app.py
-import streamlit as st
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from gtts import gTTS
 import os
 import time
-import torch
-from threading import Thread

-# Initialize models
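-# (st.cache_resource caches the loaded model across Streamlit reruns)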
-@st.cache_resource
-def load_models():
-    model_name = "Qwen/Qwen3-1.7B"
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_name,
-        trust_remote_code=True
     )

-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        torch_dtype="auto",
         device_map="auto",
-        trust_remote_code=True
     )

-    return model, tokenizer

-def parse_thinking_output(output_ids, tokenizer, thinking_token_id=151668):
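-    # token id 151668 is Qwen3's "</think>" marker; the last occurrence
-    # splits the reasoning trace from the final answer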
-    try:
-        index = len(output_ids) - output_ids[::-1].index(thinking_token_id)
-    except ValueError:
-        index = 0
-
-    thinking = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
-    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
-    return thinking, content

-def generate_response(prompt, model, tokenizer):
-    messages = [{"role": "user", "content": prompt}]
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-        enable_thinking=True
     )

-    streamer = TextIteratorStreamer(tokenizer)
-    inputs = tokenizer([text], return_tensors="pt").to(model.device)

-    generation_kwargs = dict(
-        **inputs,
-        streamer=streamer,
-        max_new_tokens=4096,
-        temperature=0.7,
-        do_sample=True
-    )

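-    # model.generate() blocks, so run it in a background thread and let
-    # the streamer yield decoded text incrementally in the main thread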
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    full_response = ""
-    thinking_content = ""
-    for new_text in streamer:
-        full_response += new_text
-        try:
-            current_ids = tokenizer.encode(full_response, return_tensors="pt")[0]
-            thinking, content = parse_thinking_output(current_ids, tokenizer)
-            yield thinking, content
-        except:
-            yield "", full_response

-def text_to_speech(text):
-    tts = gTTS(text=text, lang='en', slow=False)
-    audio_file = f"audio_{int(time.time())}.mp3"
-    tts.save(audio_file)
-    return audio_file

-# Streamlit UI
-def main():
-    st.title("🧠 Qwen3-1.7B Thinking Mode Demo")
-
-    model, tokenizer = load_models()
-
-    with st.sidebar:
-        st.header("Settings")
-        max_length = st.slider("Max Tokens", 100, 4096, 1024)
-        temperature = st.slider("Temperature", 0.1, 1.0, 0.7)
-
-    prompt = st.text_area("Enter your prompt:",
-                          "Explain quantum computing in simple terms")
-
-    if st.button("Generate Response"):
-        with st.spinner("Generating response..."):
-            # Setup containers
-            thinking_container = st.container(border=True)
-            response_container = st.empty()
-            audio_container = st.empty()
-
-            full_content = ""
-            current_thinking = ""
-
-            for thinking, content in generate_response(prompt, model, tokenizer):
-                if thinking != current_thinking:
-                    thinking_container.markdown(f"**Thinking Process:**\n{thinking}")
-                    current_thinking = thinking
-
-                if content != full_content:
-                    response_container.markdown(f"**Final Answer:**\n{content}")
-                    full_content = content
-
-            # Add audio version
-            audio_file = text_to_speech(full_content)
-            audio_container.audio(audio_file, format='audio/mp3')
-
-            # Add download button
-            st.download_button(
-                label="Download Response",
-                data=full_content,
-                file_name="qwen_response.txt",
-                mime="text/plain"
-            )

-if __name__ == "__main__":
-    main()

 import os
 import time
+import streamlit as st
+from PIL import Image
+from transformers import pipeline
+from gtts import gTTS
+import tempfile

+# --- Requirements ---
+# Update requirements.txt to include:
+"""
+streamlit>=1.20
+pillow>=9.0
+torch>=2.0.0
+transformers>=4.40
+sentencepiece>=0.2.0
+gTTS>=2.3.1
+accelerate>=0.30
+"""
+
+# --- Page Setup ---
+st.set_page_config(page_title="Magic Story Generator", layout="centered")
+st.title("📖✨ Turn Images into Children's Stories")
+
+# --- Load Pipelines (cached) ---
+@st.cache_resource(show_spinner=False)
+def load_pipelines():
+    # 1) Image-captioning pipeline (BLIP)
+    captioner = pipeline(
+        task="image-to-text",
+        model="Salesforce/blip-image-captioning-base",
+        device=-1
     )
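+    # device=-1 keeps the BLIP captioner on CPU; the storyteller below is
+    # placed automatically via device_map="auto"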

+    # 2) Story-generation pipeline using Qwen3-1.7B
+    storyteller = pipeline(
+        task="text-generation",
+        model="Qwen/Qwen3-1.7B",
         device_map="auto",
+        trust_remote_code=True,
+        torch_dtype="auto",
+        max_new_tokens=150,
+        temperature=0.7,
+        top_p=0.9,
+        repetition_penalty=1.2,
+        eos_token_id=151645  # <|im_end|> id, specific to the Qwen3 tokenizer
     )
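+    # generation settings passed at construction become the defaults for
+    # every storyteller(...) call below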

+    return captioner, storyteller
+
+captioner, storyteller = load_pipelines()

+# --- Main App ---
+uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
+if uploaded:
+    # Load and display the image
+    img = Image.open(uploaded).convert("RGB")
+    st.image(img, use_container_width=True)

+    # Generate caption
+    with st.spinner("🔍 Generating caption..."):
+        cap = captioner(img)
+        caption = cap[0].get("generated_text", "").strip() if isinstance(cap, list) else ""
+    if not caption:
+        st.error("😢 Couldn't understand this image. Try another one!")
+        st.stop()
+    st.success(f"**Caption:** {caption}")
+
+    # Build prompt and generate story
+    prompt = (
+        f"<|im_start|>system\n"
+        f"You are a children's story writer. Create a 50-100 word story based on this image description: {caption}\n"
+        f"<|im_end|>\n"
+        f"<|im_start|>user\n"
+        f"Write a coherent, child-friendly story that flows naturally with simple vocabulary.<|im_end|>\n"
+        f"<|im_start|>assistant\n"
     )
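+    # (the raw <|im_start|>/<|im_end|> markers mirror Qwen's ChatML format)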

+    with st.spinner("📝 Writing story..."):
+        start = time.time()
+        out = storyteller(
+            prompt,
+            do_sample=True,
+            num_return_sequences=1
+        )
+        gen_time = time.time() - start
+        st.text(f"⏱ Generated in {gen_time:.1f}s")
+
+    # Process output
+    story = out[0]['generated_text'].split("<|im_start|>assistant\n")[-1]
+    story = story.replace("<|im_end|>", "").strip()

+    # Enforce ≤100 words and proper ending
+    words = story.split()
+    if len(words) > 100:
+        story = " ".join(words[:100])
+    if not story.endswith(('.', '!', '?')):
+        story += '.'
+
+    # Display story
+    st.subheader("📚 Your Magical Story")
+    st.write(story)

+    # Convert to audio
+    with st.spinner("🔊 Converting to audio..."):
+        try:
+            tts = gTTS(text=story, lang="en", slow=False)
+            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
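+            # delete=False keeps the mp3 on disk after the handle closes,
+            # so st.audio can still read it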
+            tts.save(tmp.name)
+            st.audio(tmp.name, format="audio/mp3")
+        except Exception as e:
+            st.warning(f"⚠️ TTS failed: {e}")

+# Footer
+st.markdown("---\nMade with ❤️ by your friendly story wizard")