import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
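
# Note: LLaMA-2-13B weights in float16 occupy roughly 26 GB; device_map="auto"
# lets Accelerate shard the model across available GPUs and offload the
# remainder to CPU/disk if needed (at a significant latency cost).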
@st.cache_resource
def load_model_and_tokenizer():
    """
    Load the model and tokenizer once, using Streamlit's resource caching
    to avoid reloading them on every rerun.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "namannn/llama2-13b-hyperbolic-cluster-pruned",
            use_fast=True,  # Use the fast (Rust-backed) tokenizer if available
            trust_remote_code=True,  # Allow custom tokenizer code from the repo
        )
        # LLaMA tokenizers ship without a pad token; fall back to EOS
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            "namannn/llama2-13b-hyperbolic-cluster-pruned",
            device_map="auto",  # Let Accelerate place layers on available devices
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,  # Allow custom model code from the repo
        )
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        raise
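
# If GPU memory is tight, 8-bit loading is an option (a sketch; assumes the
# bitsandbytes package is installed in the Space):
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       "namannn/llama2-13b-hyperbolic-cluster-pruned",
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       trust_remote_code=True,
#   )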
def generate_text(prompt, tokenizer, model, max_new_tokens):
    """
    Generate text with the loaded model and tokenizer, with detailed error handling.
    """
    try:
        # Move inputs to the same device as the model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate with explicit sampling parameters
        with torch.no_grad():  # No gradients needed for inference
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,  # Cap on generated tokens, excluding the prompt
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        return generated_text.strip()
    except Exception as e:
        st.error(f"Error generating text: {e}")
        return None
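
# Optional streaming variant (a sketch, not wired into main() below): yields
# text chunks as they are generated so the UI could render them incrementally,
# e.g. via st.write_stream. Assumes transformers' TextIteratorStreamer is
# available in the installed version.
def generate_text_stream(prompt, tokenizer, model, max_new_tokens):
    from threading import Thread

    from transformers import TextIteratorStreamer

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a background thread and consume the streamer here
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, max_new_tokens=max_new_tokens, do_sample=True, streamer=streamer),
    ).start()
    for chunk in streamer:
        yield chunk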
def main():
    # Page configuration
    st.set_page_config(page_title="LLaMa2 Text Generation", page_icon="✍️")

    # Page title and description
    st.title("Text Generation with LLaMa2-13b Hyperbolic Model")
    st.write("Enter a prompt below and the model will generate text.")

    # Load model and tokenizer
    try:
        tokenizer, model = load_model_and_tokenizer()
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return

    # System information
    st.sidebar.header("System Information")
    st.sidebar.write(f"Device: {model.device}")
    st.sidebar.write(f"Model Dtype: {model.dtype}")

    # User input for the prompt
    prompt = st.text_area("Input Prompt", "Once upon a time, in a land far away")

    # Slider controlling how many new tokens to generate
    max_new_tokens = st.slider("Max New Tokens", min_value=50, max_value=500, value=150)

    # Button to trigger text generation
    if st.button("Generate Text"):
        if prompt:
            try:
                generated_text = generate_text(prompt, tokenizer, model, max_new_tokens)
                if generated_text:
                    st.subheader("Generated Text:")
                    st.write(generated_text)
                else:
                    st.warning("No text was generated. Please check the input and try again.")
            except Exception as e:
                st.error(f"Unexpected error during text generation: {e}")
        else:
            st.warning("Please enter a prompt to generate text.")

if __name__ == "__main__":
main()