import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
from threading import Thread
import PyPDF2
import pandas as pd
import torch
import time

st.set_page_config(
    page_title="WizNerd Insp",
    page_icon="🚀",
    layout="centered"
)

MODEL_NAME = "amiguel/SmolLM2-360M-concise-reasoning"

st.title("🚀 WizNerd Insp 🚀")
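
# Sidebar: Hugging Face authentication and document upload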
with st.sidebar:
    st.header("Authentication 🔑")
    hf_token = st.text_input("Hugging Face Token", type="password",
                             help="Get your token from https://huggingface.co/settings/tokens")

    st.header("Upload Documents 📁")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )
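
# Persist the chat history across Streamlit reruns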
if "messages" not in st.session_state:
    st.session_state.messages = []
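

# Extract text from an uploaded PDF or XLSX file (cached so the same file is only parsed once)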
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""

    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # extract_text() can return None for pages with no extractable text
            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
        return ""
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""
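

# Load the tokenizer and model; st.cache_resource keeps them in memory across reruns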
@st.cache_resource
def load_model(hf_token):
    try:
        if not hf_token:
            st.error("🔒 Authentication required! Please provide a Hugging Face token.")
            return None

        login(token=hf_token)

        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=hf_token
        )

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float16,
            token=hf_token
        )

        return model, tokenizer

    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
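

# Start generation in a background thread (with the KV cache enabled via use_cache)
# and return a TextIteratorStreamer that yields decoded text as it is produced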
def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        **inputs,
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": use_cache,
        "streamer": streamer
    }

    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer
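

# Replay the stored chat history on every rerun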
for message in st.session_state.messages:
    try:
        avatar = "👤" if message["role"] == "user" else "🤖"
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
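
# Handle a new user prompt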
if prompt := st.chat_input("Ask your inspection question..."):
    if not hf_token:
        st.error("🔒 Authentication required!")
        st.stop()
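
    # Load the model on first use and keep it in session state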
    if "model" not in st.session_state:
        model_data = load_model(hf_token)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()

        st.session_state.model, st.session_state.tokenizer = model_data

    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
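
    # Echo the user message and add it to the chat history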
    with st.chat_message("user", avatar="👤"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
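
    # Build the document context from the uploaded file, if any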
    file_context = process_file(uploaded_file)
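
    # Generate the assistant response and stream it into the chat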
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar="🤖"):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)

                response_container = st.empty()
                full_response = ""

                # Append streamed chunks as they arrive, stripping any <think> markers
                for chunk in streamer:
                    full_response += chunk.replace("<think>", "").replace("</think>", "")
                    response_container.markdown(full_response + "▌")

                end_time = time.time()
                st.caption(f"Generated in {end_time - start_time:.2f}s using KV caching")

                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})

        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")