import os

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the Hugging Face access token from a local .env file
load_dotenv()
api_key = os.getenv("api_key")

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Cache the model so Streamlit does not reload it on every rerun
@st.cache_resource
def load_model():
    try:
        # Note: the tokenizer and model come from different repos; this only
        # works if KhunPop/Gardening is Gemma-based and shares the same vocabulary
        tokenizer = AutoTokenizer.from_pretrained("KhunPop/Gardening", token=api_key)
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=api_key)
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None

# Load model and tokenizer
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()

# Default to CPU, or use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize session state messages if not already initialized
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display the conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

def generate_response(prompt):
    try:
        # Tokenize the input prompt
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, padding=True, max_length=512
        ).to(device)

        # Generate a completion; unpacking `inputs` passes the attention mask
        # along with the input IDs
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=150, temperature=0.7, do_sample=True
            )

        # Decode only the newly generated tokens, skipping the echoed prompt
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
        )
        return response
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."

# User input field for asking questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    # Display user message
    with st.chat_message("user"):
        st.write(user_input)

    # Generate and display assistant's response
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_response(user_input)
            st.write(response)

    # Update session state with the new conversation
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})
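# Optional improvement (a sketch, not wired into the app above): gemma-2b-it is an
# instruction-tuned chat model, so it usually answers better when the prompt is
# built with the tokenizer's chat template instead of the raw question text.
# This assumes the loaded tokenizer actually ships a Gemma chat template; Gemma's
# template also requires the conversation to start with a user turn and to
# alternate user/assistant roles, so the assistant greeting is skipped here.
def build_prompt(messages):
    # Keep only the turns from the first user message onward so the
    # template's role-alternation rules are satisfied
    first_user = next(i for i, m in enumerate(messages) if m["role"] == "user")
    return tokenizer.apply_chat_template(
        messages[first_user:], tokenize=False, add_generation_prompt=True
    )

# Possible usage inside the `if user_input:` block, before updating session state:
#   prompt = build_prompt(
#       st.session_state.messages + [{"role": "user", "content": user_input}]
#   )
#   response = generate_response(prompt)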
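# To run locally (assumes this file is saved as app.py, with a .env file next to
# it containing a line like `api_key=hf_...` holding your Hugging Face token):
#   pip install streamlit torch transformers python-dotenv
#   streamlit run app.py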