import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Function to load model only once
def load_model():
    try:
        # If model and tokenizer are already in session state, return them
        if "tokenizer" in st.session_state and "model" in st.session_state:
            return st.session_state.tokenizer, st.session_state.model
        else:
            tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained("CopyleftCultivars/Gemma2B-NaturalFarmerV1", use_auth_token=HF_TOKEN)
            # Store the model and tokenizer in session state
            st.session_state.tokenizer = tokenizer
            st.session_state.model = model
            return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None

# Load model and tokenizer (cached)
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()
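
# Note: Streamlit's @st.cache_resource decorator is the idiomatic way to keep a
# heavyweight object like this model alive across reruns. A minimal alternative
# sketch (same repositories and token as above, shown only as a commented
# suggestion, not part of the running app):
#
# @st.cache_resource
# def load_model_cached():
#     tok = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", use_auth_token=HF_TOKEN)
#     mdl = AutoModelForCausalLM.from_pretrained("CopyleftCultivars/Gemma2B-NaturalFarmerV1", use_auth_token=HF_TOKEN)
#     return tok, mdl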

# Default to CPU, or use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize session state messages
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Create a text area to display logs
log_box = st.empty()

# Function to generate response with debugging logs
def generate_response(prompt):
    try:
        # Tokenize input prompt with dynamic padding and truncation
        log_box.text_area("Debugging Logs", "Tokenizing the prompt...", height=200)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)

        # Display tokenized inputs
        log_box.text_area("Debugging Logs", f"Tokenized inputs: {inputs['input_ids']}", height=200)

        # Generate output from model, passing the attention mask explicitly
        log_box.text_area("Debugging Logs", "Generating output...", height=200)
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
        )

        # Display the raw output from the model
        log_box.text_area("Debugging Logs", f"Raw model output (tokens): {outputs}", height=200)

        # Decode only the newly generated tokens: for causal LMs, generate()
        # returns the prompt tokens followed by the completion
        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

        # Display the final decoded response
        log_box.text_area("Debugging Logs", f"Decoded response: {response}", height=200)
        return response
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."

# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)

    # Update session state
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})