Update app.py
app.py CHANGED
@@ -6,11 +6,9 @@ from huggingface_hub import login
 import re
 import os
 
-# Load Hugging Face token
 HF_TOKEN = os.getenv("HF_TOKEN")
 login(token=HF_TOKEN)
 
-# Define models
 MODELS = {
     "athena-1": {
         "name": "🦁 Atlas-Flash",
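A note on the token handling this hunk keeps: `os.getenv("HF_TOKEN")` returns `None` when the Space has no `HF_TOKEN` secret, and `login(token=None)` will then fail or try to prompt interactively at startup. A minimal defensive sketch, reusing the names above:

```python
# Sketch: only authenticate when the HF_TOKEN secret is actually set.
import os
from huggingface_hub import login

HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)  # needed only for gated or private model repos
```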
@@ -22,9 +20,9 @@ MODELS = {
     },
 }
 
-
-USER_PFP = "user.png"
-AI_PFP = "ai_pfp.png"
+
+USER_PFP = "user.png"
+AI_PFP = "ai_pfp.png"
 
 class AtlasInferenceApp:
     def __init__(self):
@@ -61,17 +59,17 @@ class AtlasInferenceApp:
 
         model_path = MODELS[model_key]["sizes"][model_size]
 
-
+
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
-            device_map="cpu",
-            torch_dtype=torch.float32,
+            device_map="cpu",
+            torch_dtype=torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
 
-
+
        st.session_state.current_model.update({
            "tokenizer": tokenizer,
            "model": model,
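For context on the loading flags in this hunk: `device_map="cpu"` keeps every layer on the CPU, `torch_dtype=torch.float32` stores weights at 4 bytes per parameter, and `low_cpu_mem_usage=True` loads the checkpoint without materializing a second full copy of the weights in RAM. A quick footprint check, assuming `model` is the instance loaded above:

```python
# Rough RAM estimate for a float32 CPU load: 4 bytes per parameter.
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e9:.2f}B params ≈ {num_params * 4 / 1e9:.1f} GB in float32")
```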
@@ -89,7 +87,7 @@ class AtlasInferenceApp:
             return "⚠️ Please select and load a model first"
 
         try:
-
+
             system_instruction = "You are Atlas, a helpful AI assistant trained to help the user. You are a Deepseek R1 fine-tune."
             prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
 
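The prompt uses the Alpaca-style `### Instruction:` / `### Response:` layout. Because a causal LM's decoded output echoes the prompt, the reply is whatever follows the last `### Response:` marker, which is what the `split("### Response:")` in the generation hunk below relies on. A self-contained round-trip, where `decoded` is a stand-in for real model output:

```python
# Round-trip of the Alpaca-style prompt; `decoded` is fabricated for illustration.
system_instruction = "You are Atlas, a helpful AI assistant trained to help the user."
message = "What is the capital of France?"  # example user message
prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"

decoded = prompt + " Paris."                        # stand-in for tokenizer.decode(...)
reply = decoded.split("### Response:")[-1].strip()  # -> "Paris."
print(reply)
```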
@@ -101,8 +99,11 @@ class AtlasInferenceApp:
                 padding=True
             )
 
+
+            response_container = st.empty()
+            full_response = ""
             with torch.no_grad():
-                outputs = st.session_state.current_model["model"].generate(
+                for chunk in st.session_state.current_model["model"].generate(
                     input_ids=inputs.input_ids,
                     attention_mask=inputs.attention_mask,
                     max_new_tokens=max_tokens,
@@ -112,9 +113,13 @@ class AtlasInferenceApp:
                     do_sample=True,
                     pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                     eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
-                )
-                response = st.session_state.current_model["tokenizer"].decode(outputs[0], skip_special_tokens=True)
-            return response.split("### Response:")[-1].strip()
+                    streamer=None,  # Use a custom streamer for real-time updates
+                ):
+                    chunk_text = st.session_state.current_model["tokenizer"].decode(chunk, skip_special_tokens=True)
+                    full_response += chunk_text
+                    response_container.markdown(full_response)
+
+            return full_response.split("### Response:")[-1].strip()
         except Exception as e:
             return f"⚠️ Generation Error: {str(e)}"
         finally:
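A caution on the streaming loop this commit introduces: `generate()` is not a generator. It returns a complete tensor of token ids only after generation finishes, so `for chunk in ...generate(...)` iterates over finished sequences (one per batch entry), not over tokens as they are produced, and `streamer=None` is a no-op placeholder. The supported streaming pattern in transformers is `TextIteratorStreamer` with `generate()` running in a background thread. A sketch, reusing `inputs` and `max_tokens` from the diff and writing `model` / `tokenizer` for the objects kept in `st.session_state.current_model`:

```python
# Incremental streaming sketch; model, tokenizer, inputs, and max_tokens are
# assumed to be in scope as in the surrounding code.
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=max_tokens,
    do_sample=True,
    streamer=streamer,
)
# generate() blocks until done, so run it in a worker thread and drain the
# streamer on the main thread to update the page as text arrives.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

response_container = st.empty()
full_response = ""
for text_chunk in streamer:  # yields decoded text incrementally
    full_response += text_chunk
    response_container.markdown(full_response)
thread.join()
```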
@@ -154,7 +159,6 @@ class AtlasInferenceApp:
 
         st.markdown("*⚠️ CAUTION: Atlas is an experimental model and this is just a preview. Responses may not be expected. Please double-check sensitive information!*")
 
-        # Display chat history
         for message in st.session_state.chat_history:
             with st.chat_message(
                 message["role"],
@@ -162,7 +166,6 @@ class AtlasInferenceApp:
             ):
                 st.markdown(message["content"])
 
-        # Input box for user messages
         if prompt := st.chat_input("Message Atlas..."):
             st.session_state.chat_history.append({"role": "user", "content": prompt})
             with st.chat_message("user", avatar=USER_PFP):
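The final hunk ends mid-turn; for reference, the usual Streamlit shape of the rest of the turn renders the user message, generates the assistant reply, and appends it to history so it survives the next rerun. A sketch in which `app.respond` is a hypothetical stand-in for this class's generation method:

```python
# Hypothetical completion of the chat turn; `app.respond` is not from the diff.
if prompt := st.chat_input("Message Atlas..."):
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user", avatar=USER_PFP):
        st.markdown(prompt)
    with st.chat_message("assistant", avatar=AI_PFP):
        reply = app.respond(prompt)  # generate and render the reply
        st.markdown(reply)
    st.session_state.chat_history.append({"role": "assistant", "content": reply})
```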