# Scrape residue from the Hugging Face Spaces page header (not code):
#   Spaces: Sleeping
#   Sleeping
import os

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Pull the Hugging Face access token out of a local .env file.
load_dotenv()
api_key = os.getenv("api_key")

# Page header.
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")
@st.cache_resource
def load_model():
    """Load the chat tokenizer and model once per server process.

    Cached with st.cache_resource so Streamlit reruns (every user
    interaction) do not re-download and re-instantiate the model.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure —
        the error is surfaced in the Streamlit UI via st.error.
    """
    try:
        # NOTE(review): tokenizer and model come from different repos
        # ("KhunPop/Gardening" vs "google/gemma-2b-it") — confirm the
        # tokenizer is actually compatible with the Gemma checkpoint.
        # `token=` replaces the deprecated `use_auth_token=` argument.
        tokenizer = AutoTokenizer.from_pretrained(
            "KhunPop/Gardening", token=api_key
        )
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b-it", token=api_key
        )
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None
# Fetch the cached tokenizer/model pair; halt the script if loading failed
# (load_model already reported the error in the UI).
tokenizer, model = load_model()
if tokenizer is None or model is None:
    st.stop()

# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Seed the conversation with a greeting the first time this session runs.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {
            "role": "assistant",
            "content": "Hello there! How can I help you with gardening today?",
        }
    ]

# Replay the stored conversation so it survives Streamlit reruns.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.write(msg["content"])
def generate_response(prompt):
    """Generate a reply to *prompt* with the causal language model.

    Args:
        prompt: The user's question as plain text.

    Returns:
        The newly generated text only (the prompt tokens are stripped
        before decoding), or a fallback apology string if generation
        raises — the error is shown in the UI via st.error.
    """
    try:
        # Tokenize; truncate long prompts to the 512-token context budget.
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(device)
        # Pass the full tokenizer output (input_ids AND attention_mask):
        # generating from input_ids alone makes the model attend to
        # padding as if it were real content.
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
        )
        # Decode only the tokens produced after the prompt, so the
        # user's question is not echoed back inside the answer.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."
# One user turn per Streamlit rerun.
user_input = st.chat_input("Type your gardening question here:")
if user_input:
    # Echo the user's message right away.
    with st.chat_message("user"):
        st.write(user_input)

    # Produce and render the assistant's reply.
    with st.chat_message("assistant"):
        with st.spinner("I'm gonna tell you..."):
            response = generate_response(user_input)
            st.write(response)

    # Persist both turns so the history replays on the next rerun.
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})