frankjosh commited on
Commit
178a171
·
verified ·
1 Parent(s): 5d9493d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -15,6 +15,8 @@ st.set_page_config(page_title="Repository Recommender", layout="wide")
15
  def load_model():
16
  model_name = "Salesforce/codet5-small"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
18
  model = AutoModel.from_pretrained(model_name).to("cuda")
19
  return tokenizer, model
20
 
 
15
  def load_model():
16
  model_name = "Salesforce/codet5-small"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ # Check if GPU is available
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
  model = AutoModel.from_pretrained(model_name).to("cuda")
21
  return tokenizer, model
22