mayf committed
Commit c4110d1 · verified · 1 Parent(s): 799d95f

Update app.py

Files changed (1)
  1. app.py +110 -71
app.py CHANGED
@@ -1,86 +1,125 @@
- import os
- import time
  import streamlit as st
- from PIL import Image
- from transformers import pipeline
  from gtts import gTTS
- import tempfile
- from llama_cpp import Llama
-
- # First install required package:
- # pip install llama-cpp-python
-
- # —––––––– Page Setup —–––––––
- st.set_page_config(page_title="Magic Story Generator", layout="centered")
- st.title("📖✨ Turn Images into Children's Stories")

- # —––––––– Load Models (cached) —–––––––
- @st.cache_resource(show_spinner=False)
  def load_models():
-     # 1) Image captioning model
-     captioner = pipeline(
-         "image-to-text",
-         model="Salesforce/blip-image-captioning-base"
      )

-     # 2) GGUF Story Model
-     storyteller = Llama(
-         model_path="DavidAU/L3-Grand-Story-Darkness-MOE-4X8-24.9B-e32-GGUF",
-         n_ctx=2048,
-         n_threads=4,
-         n_gpu_layers=0  # Set based on your GPU capacity
      )
-     return captioner, storyteller
-
- captioner, storyteller = load_models()
-
- # —––––––– Main App —–––––––
- uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
- if uploaded:
-     img = Image.open(uploaded).convert("RGB")
-     st.image(img, use_column_width=True)
-
-     # Generate caption
-     with st.spinner("🔍 Generating caption..."):
-         cap = captioner(img)
-         caption = cap[0]['generated_text']
-     st.success(f"**Caption:** {caption}")
-
-     # Generate story
-     prompt = f"""Below is an image description. Write a children's story based on it.
-
- Image Description: {caption}
- Story:"""
-
-     with st.spinner("📝 Crafting magical story..."):
-         start = time.time()
-         output = storyteller(
-             prompt=prompt,
-             max_tokens=500,
-             temperature=0.7,
-             top_p=0.9,
-             repeat_penalty=1.1
-         )
-         gen_time = time.time() - start
-         story = output['choices'][0]['text'].strip()
-         st.text(f"⏱ Generated in {gen_time:.1f}s")
-
-     # Post-process story
-     story = story.split("###")[0].strip()  # Remove any trailing artifacts
-
-     # Display story
-     st.subheader("📚 Your Magical Story")
-     st.write(story)
-
-     # Audio conversion
-     with st.spinner("🔊 Converting to audio..."):
-         try:
-             tts = gTTS(text=story, lang="en", slow=False)
-             with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
-                 tts.save(tmp.name)
-             st.audio(tmp.name, format="audio/mp3")
-         except Exception as e:
-             st.warning(f"⚠️ Audio conversion failed: {str(e)}")
-
- # Footer
- st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
 
+ # app.py
  import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  from gtts import gTTS
+ import os
+ import time
+ import torch
+ from threading import Thread
+
+ # Initialize models
+ @st.cache_resource
  def load_models():
+     model_name = "Qwen/Qwen3-1.7B"
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
      )

+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype="auto",
+         device_map="auto",
+         trust_remote_code=True
      )
+
+     return model, tokenizer
+
+ def parse_thinking_output(output_ids, tokenizer, thinking_token_id=151668):
+     # 151668 is the </think> token id in the Qwen3 tokenizer: everything up to
+     # its last occurrence is the thinking trace, the rest is the final answer.
+     try:
+         index = len(output_ids) - output_ids[::-1].index(thinking_token_id)
+     except ValueError:
+         index = 0
+
+     thinking = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
+     content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+     return thinking, content
+
+ def generate_response(prompt, model, tokenizer, max_new_tokens=4096, temperature=0.7):
+     messages = [{"role": "user", "content": prompt}]
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+         enable_thinking=True
+     )
+
+     # skip_prompt=True so the prompt text is not streamed back and
+     # re-parsed as if it were part of the thinking trace
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
+     inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+     generation_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         do_sample=True
+     )
+
+     # Run generation in a background thread; the streamer yields text chunks
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     full_response = ""
+     for new_text in streamer:
+         full_response += new_text
+         try:
+             # encode() returns a plain list of ids, which parse_thinking_output
+             # can reverse with [::-1] (a tensor cannot be sliced that way)
+             current_ids = tokenizer.encode(full_response)
+             thinking, content = parse_thinking_output(current_ids, tokenizer)
+             yield thinking, content
+         except Exception:
+             yield "", full_response
+     thread.join()
+
+ def text_to_speech(text):
+     tts = gTTS(text=text, lang='en', slow=False)
+     audio_file = f"audio_{int(time.time())}.mp3"
+     tts.save(audio_file)
+     return audio_file
+
+ # Streamlit UI
+ def main():
+     st.title("🧠 Qwen3-1.7B Thinking Mode Demo")

+     model, tokenizer = load_models()

+     with st.sidebar:
+         st.header("Settings")
+         max_length = st.slider("Max Tokens", 100, 4096, 1024)
+         temperature = st.slider("Temperature", 0.1, 1.0, 0.7)

+     prompt = st.text_area("Enter your prompt:",
+                           "Explain quantum computing in simple terms")
+
+     if st.button("Generate Response"):
+         with st.spinner("Generating response..."):
+             # Set up containers
+             thinking_container = st.container(border=True)
+             response_container = st.empty()
+             audio_container = st.empty()
+
+             full_content = ""
+             current_thinking = ""
+
+             # Wire the sidebar sliders into generation
+             for thinking, content in generate_response(
+                 prompt, model, tokenizer,
+                 max_new_tokens=max_length, temperature=temperature
+             ):
+                 if thinking != current_thinking:
+                     thinking_container.markdown(f"**Thinking Process:**\n{thinking}")
+                     current_thinking = thinking
+
+                 if content != full_content:
+                     response_container.markdown(f"**Final Answer:**\n{content}")
+                     full_content = content
+
+             # Add audio version
+             audio_file = text_to_speech(full_content)
+             audio_container.audio(audio_file, format='audio/mp3')
+
+             # Add download button
+             st.download_button(
+                 label="Download Response",
+                 data=full_content,
+                 file_name="qwen_response.txt",
+                 mime="text/plain"
+             )

+ if __name__ == "__main__":
+     main()
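
Note for anyone adapting this revision: the `151668` default in `parse_thinking_output` is the id of the Qwen3 tokenizer's `</think>` token. A minimal sketch of looking the id up at runtime instead of hardcoding it (assumes the `Qwen/Qwen3-1.7B` checkpoint is downloadable in your environment; `convert_tokens_to_ids` is a standard tokenizer method):

```python
from transformers import AutoTokenizer

# Sketch: derive the </think> id from the tokenizer rather than hardcoding it.
# Assumes the Qwen/Qwen3-1.7B checkpoint is reachable from this environment.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
think_end_id = tok.convert_tokens_to_ids("</think>")
print(think_end_id)  # expected: 151668 for current Qwen3 tokenizers
```

Running the updated Space locally should need roughly `pip install streamlit transformers torch gTTS accelerate` (the list is inferred from the imports; `accelerate` is an assumption on account of `device_map="auto"`), then `streamlit run app.py`.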