Mattral committed
Commit ddcad02 · verified · 1 Parent(s): 6a3b5b2

Update app.py

Files changed (1): app.py (+17 -13)
app.py CHANGED
@@ -17,21 +17,23 @@ def load_model():
         st.error(f"Failed to load model: {e}")
         return None, None
 
+# Load model and tokenizer
 tokenizer, model = load_model()
 
 if not tokenizer or not model:
     st.stop()
 
-# Default to CPU
-device = torch.device("cpu")
-if model is not None:
-    model = model.to(device)
+# Default to CPU, or use GPU if available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = model.to(device)
 
-if "messages" not in st.session_state.keys():
+# Initialize session state messages if not already initialized
+if "messages" not in st.session_state:
     st.session_state.messages = [
         {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
     ]
 
+# Display the conversation history
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.write(message["content"])
@@ -39,11 +41,11 @@ for message in st.session_state.messages:
 def generate_response(prompt):
     try:
         # Tokenize the input prompt
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
-
-        # Make sure that inputs are passed properly to the model
-        outputs = model.generate(inputs["input_ids"], max_new_tokens=150, temperature=0.7)
-
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
+
+        # Ensure the model is generating properly (without a target)
+        outputs = model.generate(inputs["input_ids"], max_new_tokens=150, temperature=0.7, do_sample=True)
+
         # Decode the output to text
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         return response
@@ -51,7 +53,7 @@ def generate_response(prompt):
         st.error(f"Error during text generation: {e}")
         return "Sorry, I couldn't process your request."
 
-
+# User input field for asking questions
 user_input = st.chat_input("Type your gardening question here:")
 
 if user_input:
@@ -59,10 +61,12 @@ if user_input:
     with st.chat_message("user"):
         st.write(user_input)
 
-    with st.chat_message("assistant"):
+    # Generate and display assistant's response
+    with st.chat_message("assistant"):
         with st.spinner("I'm gonna tell you..."):
             response = generate_response(user_input)
             st.write(response)
 
-    st.session_state.messages.append({"role": "user", "content": user_input})
-    st.session_state.messages.append({"role": "assistant", "content": response})
+    # Update session state with the new conversation
+    st.session_state.messages.append({"role": "user", "content": user_input})
+    st.session_state.messages.append({"role": "assistant", "content": response})
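Note: the body of load_model() sits mostly outside this diff; only its error path is visible. For context, here is a minimal sketch of what such a loader might look like, assuming a Hugging Face transformers causal LM. The checkpoint name and the st.cache_resource decorator are illustrative assumptions, not taken from this commit:

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "distilgpt2"  # hypothetical placeholder; the real checkpoint is not shown in this diff

@st.cache_resource  # cache across Streamlit reruns so the model loads only once
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None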
 
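On the generate() change: temperature only takes effect when do_sample=True, which this commit adds; without it, transformers falls back to greedy decoding and warns that the flag is unused. A slightly hardened generate_response, sketched under the assumption of a decoder-only model (where generate() echoes the prompt in its output), might also forward the attention mask and an explicit pad token id:

def generate_response(prompt):
    try:
        # Tokenize the input prompt; truncation keeps long inputs inside the context window
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)

        # Forwarding the attention mask and pad_token_id avoids transformers
        # warnings and ambiguous padding during generation
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=150,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, dropping the echoed prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."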
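Two smaller points about the rest of the change: with device chosen once and used by both model.to(device) and the tokenized inputs' .to(device) call, the app runs unchanged on CPU-only hardware and picks up CUDA when present. And the old `if model is not None:` guard was dead code, since st.stop() already halts the script whenever load_model() returns None.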