import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import requests
# Local cache path for the model, so it is not re-downloaded on every app restart
MODEL_PATH = "/mnt/data/Phi-4-Hindi"
TOKEN = os.environ.get("HF_TOKEN")
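# NOTE: HF_TOKEN is assumed to be provided via the environment (e.g. a Space
# secret); it is only required if the checkpoint is gated or private.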
MODEL_NAME = "DrishtiSharma/Phi-4-Hindi-quantized"
# Load Model & Tokenizer Once
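# st.cache_resource keeps a single model/tokenizer instance per process, so
# Streamlit reruns triggered by widget interaction do not reload the weights.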
@st.cache_resource()
def load_model():
with st.spinner("Loading model... Please wait ⏳"):
try:
if not os.path.exists(MODEL_PATH):
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, token=TOKEN, trust_remote_code=True, torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)
model.save_pretrained(MODEL_PATH)
tokenizer.save_pretrained(MODEL_PATH)
else:
                # Reload from the local cache with the same settings as the
                # first download (otherwise the model comes back in float32)
                model = AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16
                )
                tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
except requests.exceptions.ConnectionError:
st.error("⚠️ Connection error! Unable to download the model. Please check your internet connection and try again.")
return None, None
except requests.exceptions.ReadTimeout:
st.error("⚠️ Read Timeout! The request took too long. Please try again later.")
return None, None
return model, tokenizer
# Load and move model to appropriate device
model, tok = load_model()
if model is None or tok is None:
st.stop()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
model = model.to(device)
except torch.cuda.OutOfMemoryError:
st.error("⚠️ CUDA Out of Memory! Running on CPU instead.")
device = torch.device("cpu")
model = model.to(device)
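# Caveat: the weights were loaded in bfloat16; on CPUs without native bf16
# support inference can be very slow, and casting to float32 may be preferable.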
terminators = [tok.eos_token_id]
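# generate() stops at any id in this list; if the model card defines extra
# end-of-turn tokens they could be appended here (eos alone is assumed enough).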
# Initialize session state if not set
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Chat function
def chat(message, temperature, do_sample, max_tokens):
"""Processes chat input and generates a response using the model."""
# Append new message to history
st.session_state.chat_history.append({"role": "user", "content": message})
# Convert chat history into model-friendly format
messages = tok.apply_chat_template(st.session_state.chat_history, tokenize=False, add_generation_prompt=True)
model_inputs = tok([messages], return_tensors="pt").to(device)
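    # apply_chat_template (with tokenize=False) renders the history into the
    # model's prompt format as a single string, assuming the checkpoint ships
    # a chat template; the string is then tokenized above as a batch of one.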
# Initialize streamer for token-wise response
streamer = TextIteratorStreamer(tok, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
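    # TextIteratorStreamer is a blocking iterator: generate() (started in a
    # worker thread below) pushes decoded tokens into it while this function
    # consumes them, which is what enables token-by-token streaming.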
# Define generation parameters
    generate_kwargs = {
        "inputs": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": max_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "eos_token_id": terminators,
    }
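    # Sampling with temperature 0 is ill-defined in transformers, so fall back
    # to greedy decoding in that case: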
if temperature == 0:
generate_kwargs["do_sample"] = False
# Generate response asynchronously
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
# Collect response as it streams
response_text = ""
for new_text in streamer:
response_text += new_text
yield response_text
# Save the assistant's response to session history
st.session_state.chat_history.append({"role": "assistant", "content": response_text})
# UI Setup
st.title("💬 Chat With Phi-4-Hindi")
st.success("✅ Model is READY to chat!")
st.markdown("Chat with [large-traversaal/Phi-4-Hindi](https://huggingface.co/large-traversaal/Phi-4-Hindi)")
# Sidebar Chat Settings
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3, 0.1)
do_sample = st.sidebar.checkbox("Use Sampling", value=True)
max_tokens = st.sidebar.slider("Max Tokens", 128, 4096, 512, 1)
text_color = st.sidebar.selectbox("Text Color", ["Red", "Black", "Blue", "Green", "Purple"], index=0)
dark_mode = st.sidebar.checkbox("🌙 Dark Mode", value=False)
# Function to format chat messages
def get_html_text(text, color):
return f'<p style="color: {color.lower()}; font-size: 16px;">{text}</p>'
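# Caution: the message text is interpolated into raw HTML and rendered with
# unsafe_allow_html=True; for untrusted input, escaping it first (e.g. with
# html.escape) would be safer.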
# Display chat history
for msg in st.session_state.chat_history:
    role = "👤" if msg["role"] == "user" else "🤖"
    st.markdown(get_html_text(f"**{role}:** {msg['content']}", text_color if role == "🤖" else "black"), unsafe_allow_html=True)
# User Input Handling
user_input = st.text_input("Type your message:", "")
if st.button("Send"):
if user_input.strip():
        # chat() appends the user message to the history itself, so it is not
        # appended here as well (doing both would duplicate it in the transcript)
        # Stream the chatbot response into a placeholder as it is generated
        with st.spinner("Generating response... 🤖💭"):
            placeholder = st.empty()
            for partial in chat(user_input, temperature, do_sample, max_tokens):
                placeholder.markdown(partial)  # latest accumulated text so far
            st.success("✅ Response generated!")
        # chat() has already stored both messages in the history; rerun so the
        # full conversation is redrawn above
        st.rerun()
if st.button("🧹 Clear Chat"):
with st.spinner("Clearing chat history..."):
st.session_state.chat_history = []
        st.success("✅ Chat history cleared!")
st.rerun()