Mattral commited on
Commit
791b656
·
verified ·
1 Parent(s): ca86f32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- from PIL import Image
5
  import os
6
 
7
 
@@ -24,7 +23,7 @@ if not tokenizer or not model:
24
  st.stop()
25
 
26
  # Default to CPU
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
  if model is not None:
29
  model = model.to(device)
30
 
@@ -39,14 +38,20 @@ for message in st.session_state.messages:
39
 
40
  def generate_response(prompt):
41
  try:
42
- inputs = tokenizer(prompt, return_tensors="pt").to(device) # Ensure inputs are moved to the device (CPU)
43
- outputs = model.generate(**inputs, max_new_tokens=150, temperature=0.7)
 
 
 
 
 
44
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
  return response
46
  except Exception as e:
47
  st.error(f"Error during text generation: {e}")
48
  return "Sorry, I couldn't process your request."
49
 
 
50
  user_input = st.chat_input("Type your gardening question here:")
51
 
52
  if user_input:
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
  import os
5
 
6
 
 
23
  st.stop()
24
 
25
  # Default to CPU
26
+ device = torch.device("cpu")
27
  if model is not None:
28
  model = model.to(device)
29
 
 
38
 
39
  def generate_response(prompt):
40
  try:
41
+ # Tokenize the input prompt
42
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
43
+
44
+ # Make sure that inputs are passed properly to the model
45
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=150, temperature=0.7)
46
+
47
+ # Decode the output to text
48
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
  return response
50
  except Exception as e:
51
  st.error(f"Error during text generation: {e}")
52
  return "Sorry, I couldn't process your request."
53
 
54
+
55
  user_input = st.chat_input("Type your gardening question here:")
56
 
57
  if user_input: