Update app.py
app.py
CHANGED
@@ -1,57 +1,68 @@
+# π DigiTwin - TinyLLaMA ChatML Inference App
 import streamlit as st
 import torch
-import time
 import os
+import time
 from threading import Thread
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-from huggingface_hub import login

-#
-HF_TOKEN =
+# --- Hugging Face Token (use Streamlit secrets for secure deployment) ---
+HF_TOKEN = st.secrets["HF_TOKEN"]

-#
-st.set_page_config(
-    page_title="GM Fine-tune Assistant π",
-    page_icon="π",
-    layout="centered"
-)
+# --- Streamlit Page Configuration ---
+st.set_page_config(page_title="DigiTwin - TinyLLaMA", page_icon="π¦", layout="centered")

+# --- App Logo and Title ---
+# st.image("assets/valonylabz_logo.png", width=160)  # Optional: Add your logo
+st.title("π DigiTwin - TinyLLaMA ChatML π")
+
+# --- Model Path ---
+MODEL_ID = "amiguel/TinyLLaMA-110M-general-knowledge"
+
+# --- System Prompt (ChatML format) ---
+SYSTEM_PROMPT = (
+    "You are DigiTwin, the digital twin of Ataliba, an inspection engineer with over 17 years "
+    "of experience in mechanical integrity, reliability, piping, and asset management. "
+    "Be precise, practical, and technical. Provide advice aligned with industry best practices."
+)

-# Avatars
+# --- Avatars ---
 USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
 BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

-#
-login(token=HF_TOKEN)
-# Load Model
+# --- Load Model & Tokenizer ---
 @st.cache_resource
-def
-    tokenizer = AutoTokenizer.from_pretrained(
+def load_tinyllama_model():
+    tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID,
+        trust_remote_code=True,
+        use_fast=False,  # SentencePiece tokenizer
+        token=HF_TOKEN
+    )
     model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
         device_map="auto",
-        torch_dtype=torch.bfloat16,
+        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+        trust_remote_code=True,
         token=HF_TOKEN
     )
     return model, tokenizer

-model, tokenizer =
+model, tokenizer = load_tinyllama_model()

-#
+# --- Build ChatML Prompt ---
+def build_chatml_prompt(messages):
+    prompt = f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
+    for msg in messages:
+        role = msg["role"]
+        prompt += f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n"
+    prompt += "<|im_start|>assistant\n"
+    return prompt

-#
-def generate_response(
-    streamer = TextIteratorStreamer(
-        skip_prompt=True,
-        skip_special_tokens=True
-    )
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+# --- Generate Response with Streaming ---
+def generate_response(prompt_text):
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

     generation_kwargs = {
         "input_ids": inputs["input_ids"],
@@ -61,55 +72,53 @@ def generate_response(prompt, model, tokenizer):
         "top_p": 0.9,
         "repetition_penalty": 1.1,
         "do_sample": True,
-        "streamer": streamer
+        "streamer": streamer,
     }

     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     return streamer

-#
-    with st.chat_message(message["role"], avatar=avatar):
-        st.markdown(message["content"])
+# --- Session State ---
+if "messages" not in st.session_state:
+    st.session_state.messages = []

-# Chat
+# --- Display Chat History ---
+for msg in st.session_state.messages:
+    avatar = USER_AVATAR if msg["role"] == "user" else BOT_AVATAR
+    with st.chat_message(msg["role"], avatar=avatar):
+        st.markdown(msg["content"])

+# --- Handle Chat Input ---
+if prompt := st.chat_input("Ask DigiTwin about inspection, piping, or reliability..."):
+    # Show user message
     with st.chat_message("user", avatar=USER_AVATAR):
         st.markdown(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})

-#
-    except Exception as e:
-        st.error(f"β‘ Generation error: {str(e)}")
-    else:
-        st.error("π€ Model not loaded!")
+    # Build ChatML-formatted prompt
+    prompt_text = build_chatml_prompt(st.session_state.messages)
+
+    with st.chat_message("assistant", avatar=BOT_AVATAR):
+        start = time.time()
+        streamer = generate_response(prompt_text)
+
+        response_area = st.empty()
+        full_response = ""
+
+        for chunk in streamer:
+            full_response += chunk.replace("<|im_end|>", "").strip() + " "
+            response_area.markdown(full_response + "β", unsafe_allow_html=True)
+
+        end = time.time()
+        input_tokens = len(tokenizer(prompt_text)["input_ids"])
+        output_tokens = len(tokenizer(full_response)["input_ids"])
+        speed = output_tokens / (end - start)
+
+        st.caption(
+            f"π Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
+            f"π Speed: {speed:.1f} tokens/sec"
+        )
+
+        response_area.markdown(full_response.strip())
+        st.session_state.messages.append({"role": "assistant", "content": full_response.strip()})
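Note on configuration: the updated app reads the Hugging Face token from Streamlit secrets instead of hard-coding it. A minimal sketch of the matching secrets entry (the HF_TOKEN key name comes from the code above; the value is a placeholder):

# .streamlit/secrets.toml
HF_TOKEN = "hf_your_token_here"

With that file in place, the app runs locally with `streamlit run app.py`.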
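For reference, build_chatml_prompt() serializes the running chat history into ChatML. For a single user turn it produces text of the following shape (the user question is illustrative; the system text is the SYSTEM_PROMPT defined above, shortened here):

<|im_start|>system
You are DigiTwin, the digital twin of Ataliba, an inspection engineer ...<|im_end|>
<|im_start|>user
What inspection intervals apply to piping circuits?<|im_end|>
<|im_start|>assistant

The trailing assistant header is what cues the model to generate a reply, and generate_response() streams tokens from that point on.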
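The streaming path itself does not depend on Streamlit, so it can be exercised from a plain Python session. A minimal sketch, assuming the checkpoint is accessible without a token (otherwise pass token=... as in the app); the system text, sample question, and max_new_tokens are illustrative:

import torch
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = "amiguel/TinyLLaMA-110M-general-knowledge"

# Same loading choices as app.py: slow (SentencePiece) tokenizer, bf16 only on GPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True,
)

# One-turn ChatML prompt; the question and abbreviated system text are placeholders.
prompt_text = (
    "<|im_start|>system\nYou are DigiTwin, an inspection engineering assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat is a corrosion loop in piping integrity work?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a background thread while the streamer yields decoded chunks.
Thread(target=model.generate, kwargs={
    "input_ids": inputs["input_ids"],
    "max_new_tokens": 128,
    "do_sample": True,
    "top_p": 0.9,
    "streamer": streamer,
}).start()

for chunk in streamer:
    print(chunk, end="", flush=True)
print()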