import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
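
# A minimal .env file next to this script might look like the following
# (the value shown is a placeholder, substitute your own Hugging Face token):
#
#   HF_TOKEN=hf_your_token_here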

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Function to load model only once
def load_model():
    try:
        # If model and tokenizer are already in session state, return them
        if "tokenizer" in st.session_state and "model" in st.session_state:
            return st.session_state.tokenizer, st.session_state.model
        else:
            # NOTE: the tokenizer and model come from different repos; this
            # assumes TheSheBots/UrbanGardening shares gemma-2b-it's vocabulary.
            # If it does not, load both from the same repository.
            tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", token=HF_TOKEN)
            model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
            # Store the model and tokenizer in session state
            st.session_state.tokenizer = tokenizer
            st.session_state.model = model
            return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None
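
# Sketch of an alternative: Streamlit's st.cache_resource decorator (available
# in Streamlit >= 1.18) caches the loaded objects across reruns and sessions,
# making the session_state bookkeeping above unnecessary. The function name
# here is hypothetical:
#
#   @st.cache_resource
#   def load_model_cached():
#       tokenizer = AutoTokenizer.from_pretrained("TheSheBots/UrbanGardening", token=HF_TOKEN)
#       model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=HF_TOKEN)
#       return tokenizer, model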

# Load model and tokenizer (cached in session state)
tokenizer, model = load_model()

if not tokenizer or not model:
    st.stop()

# Use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
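
# Optional memory saving (a sketch, not applied here): on a CUDA device the
# model can be loaded directly in half precision, roughly halving GPU memory:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "google/gemma-2b-it", token=HF_TOKEN, torch_dtype=torch.float16
#   )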

# Initialize session state messages
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Create a text area to display logs
log_box = st.empty()

# Function to generate response with debugging logs
def generate_response(prompt):
    try:
        # Tokenize input prompt with dynamic padding and truncation
        log_box.text_area("Debugging Logs", "Tokenizing the prompt...", height=200)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)

        # Display tokenized inputs
        log_box.text_area("Debugging Logs", f"Tokenized inputs: {inputs['input_ids']}", height=200)
        
        # Generate output from the model; unpacking **inputs forwards the
        # attention mask along with the input ids, avoiding warnings on padded input
        log_box.text_area("Debugging Logs", "Generating output...", height=200)
        outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)

        # Display the raw output from the model
        log_box.text_area("Debugging Logs", f"Raw model output (tokens): {outputs}", height=200)

        # Decode only the newly generated tokens so the reply does not echo the prompt
        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        
        # Display the final decoded response
        log_box.text_area("Debugging Logs", f"Decoded response: {response}", height=200)
        
        return response
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."

# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)

    # Update session state
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})