TheresaQWQ commited on
Commit
950be38
·
verified ·
1 Parent(s): 4bd5edf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -8,6 +8,8 @@ model = AutoModelForSequenceClassification.from_pretrained(
8
  trust_remote_code=True,
9
  )
10
 
 
 
11
  model.eval()
12
 
13
  def compute_scores(query, documents):
@@ -24,7 +26,7 @@ def compute_scores(query, documents):
24
  documents_list = documents.split('\n')
25
  sentence_pairs = [[query, doc] for doc in documents_list]
26
  scores = model.compute_score(sentence_pairs, max_length=1024)
27
- return scores.tolist()
28
 
29
  # Define Gradio interface
30
  iface = gr.Interface(
@@ -40,4 +42,4 @@ iface = gr.Interface(
40
 
41
  # Launch the interface
42
  if __name__ == "__main__":
43
- iface.launch()
 
8
  trust_remote_code=True,
9
  )
10
 
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ model.to(device) # Move model to GPU if available, otherwise CPU
13
  model.eval()
14
 
15
  def compute_scores(query, documents):
 
26
  documents_list = documents.split('\n')
27
  sentence_pairs = [[query, doc] for doc in documents_list]
28
  scores = model.compute_score(sentence_pairs, max_length=1024)
29
+ return scores
30
 
31
  # Define Gradio interface
32
  iface = gr.Interface(
 
42
 
43
  # Launch the interface
44
  if __name__ == "__main__":
45
+ iface.launch()