import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from gtts import gTTS
import os
import time
import torch
from threading import Thread


@st.cache_resource
def load_models():
    model_name = "Qwen/Qwen3-1.7B"

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )

    return model, tokenizer


def parse_thinking_output(output_ids, tokenizer, thinking_token_id=151668):
    # output_ids may arrive as a tensor; convert to a plain list so reverse indexing works
    if hasattr(output_ids, "tolist"):
        output_ids = output_ids.tolist()
    try:
        # Split right after the closing </think> token (id 151668 in Qwen3)
        index = len(output_ids) - output_ids[::-1].index(thinking_token_id)
    except ValueError:
        index = 0

    thinking = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    return thinking, content


def generate_response(prompt, model, tokenizer, max_new_tokens=4096, temperature=0.7):
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True
    )

    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True
    )

    # Run generation in a background thread so the streamer can be consumed here
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    full_response = ""
    for new_text in streamer:
        full_response += new_text
        try:
            # Re-encode the accumulated text so the </think> marker can be located
            current_ids = tokenizer.encode(full_response)
            thinking, content = parse_thinking_output(current_ids, tokenizer)
            yield thinking, content
        except Exception:
            yield "", full_response


def text_to_speech(text):
    # Synthesize the answer with gTTS and save it as a timestamped MP3 file
    tts = gTTS(text=text, lang='en', slow=False)
    audio_file = f"audio_{int(time.time())}.mp3"
    tts.save(audio_file)
    return audio_file


def main():
    st.title("🧠 Qwen3-1.7B Thinking Mode Demo")

    model, tokenizer = load_models()

    with st.sidebar:
        st.header("Settings")
        max_length = st.slider("Max Tokens", 100, 4096, 1024)
        temperature = st.slider("Temperature", 0.1, 1.0, 0.7)

    prompt = st.text_area("Enter your prompt:",
                          "Explain quantum computing in simple terms")

    if st.button("Generate Response"):
        with st.spinner("Generating response..."):
            thinking_container = st.container(border=True)
            # Single placeholder inside the container so updates replace rather than stack
            thinking_placeholder = thinking_container.empty()
            response_container = st.empty()
            audio_container = st.empty()

            full_content = ""
            current_thinking = ""

            for thinking, content in generate_response(
                prompt, model, tokenizer,
                max_new_tokens=max_length,
                temperature=temperature
            ):
                if thinking != current_thinking:
                    thinking_placeholder.markdown(f"**Thinking Process:**\n{thinking}")
                    current_thinking = thinking

                if content != full_content:
                    response_container.markdown(f"**Final Answer:**\n{content}")
                    full_content = content

            # Only synthesize audio once a non-empty answer has been produced
            if full_content:
                audio_file = text_to_speech(full_content)
                audio_container.audio(audio_file, format='audio/mp3')

            st.download_button(
                label="Download Response",
                data=full_content,
                file_name="qwen_response.txt",
                mime="text/plain"
            )


if __name__ == "__main__":
    main()