import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, Any
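# Assumed dependencies: streamlit, torch, and transformers must be installed;
# device_map="auto" below additionally requires the accelerate package.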
# Configure model (updated for local execution)
DEFAULT_SYSTEM_PROMPT = """You are a friendly Assistant. Provide clear, accurate, and brief answers.
Keep responses polite, engaging, and to the point. If unsure, politely suggest alternatives."""
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" # Directly specify model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
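# DEVICE is only used to place generation inputs; the model itself is placed
# by device_map="auto" (on a single-GPU machine both end up on cuda:0).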
# Page configuration
st.set_page_config(
    page_title="DeepSeek-AI R1",
    page_icon="🤖",
    layout="centered"
)
def initialize_session_state():
    """Initialize all session state variables"""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "model_loaded" not in st.session_state:
        st.session_state.update({
            "model_loaded": False,
            "model": None,
            "tokenizer": None
        })
def load_model():
    """Load model and tokenizer in half precision (float16)"""
    if not st.session_state.model_loaded:
        with st.spinner("Loading model (this may take a minute)..."):
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            st.session_state.update({
                "model": model,
                "tokenizer": tokenizer,
                "model_loaded": True
            })
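# Note: caching here relies on session_state, so each browser session reloads
# the model. st.cache_resource is Streamlit's usual way to share one model
# instance across sessions and reruns.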
def configure_sidebar() -> Dict[str, Any]:
    """Create sidebar components"""
    with st.sidebar:
        st.header("Configuration")
        return {
            "system_message": st.text_area("System Message", value=DEFAULT_SYSTEM_PROMPT, height=100),
            "max_tokens": st.slider("Max Tokens", 10, 4000, 512),
            "temperature": st.slider("Temperature", 0.1, 1.0, 0.7),
            "top_p": st.slider("Top-p", 0.1, 1.0, 0.9)
        }
def format_prompt(system_message: str, user_input: str) -> str:
    """Format prompt according to the model's expected template"""
    return f"""<|begin_of_sentence|>System: {system_message}
<|User|>{user_input}<|Assistant|>"""
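# The hand-rolled template above uses ASCII "<|...|>" markers, which the
# downstream parsing in generate_response matches literally. The released
# DeepSeek-R1 tokenizer ships its own chat template, so a more robust sketch
# (assuming a `tokenizer` in scope) would delegate formatting to it:
#
#   messages = [{"role": "system", "content": system_message},
#               {"role": "user", "content": user_input}]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)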
def generate_response(prompt: str, settings: Dict[str, Any]) -> str:
    """Generate response using the local model"""
    inputs = st.session_state.tokenizer(prompt, return_tensors="pt").to(DEVICE)
    outputs = st.session_state.model.generate(
        **inputs,  # pass input_ids and attention_mask together
        max_new_tokens=settings["max_tokens"],
        do_sample=True,  # required for temperature/top_p to take effect
        temperature=settings["temperature"],
        top_p=settings["top_p"],
        pad_token_id=st.session_state.tokenizer.eos_token_id
    )
    response = st.session_state.tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the reasoning block, then strip template markers
    response = response.split("</think>")[-1].strip()
    response = response.replace("<|User|>", "").strip()
    response = response.replace("<|System|>", "").strip()
    return response.split("<|Assistant|>")[-1].strip()
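# Alternative sketch: decoding only the newly generated tokens avoids the
# string-splitting above, since generate() returns the prompt plus completion:
#
#   new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
#   response = tokenizer.decode(new_tokens, skip_special_tokens=True)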
def handle_chat_interaction(settings: Dict[str, Any]):
    """Manage chat interactions"""
    if prompt := st.chat_input("Type your message..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)
        try:
            with st.spinner("Generating response..."):
                full_prompt = format_prompt(
                    settings["system_message"],
                    prompt
                )
                response = generate_response(full_prompt, settings)
            with st.chat_message("assistant"):
                st.markdown(response)
            st.session_state.messages.append({"role": "assistant", "content": response})
        except Exception as e:
            st.error(f"Generation error: {str(e)}")
def display_chat_history():
    """Display chat history"""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
def main():
    initialize_session_state()
    load_model()  # Load model before rendering the rest of the UI
    settings = configure_sidebar()
    st.title("🤖 DeepSeek Chat")
    st.caption(f"Running {MODEL_NAME} directly on {DEVICE.upper()}")
    display_chat_history()
    handle_chat_interaction(settings)
if __name__ == "__main__":
    main()
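# Usage (assumed filename): save as app.py and launch with
#   streamlit run app.py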