AliArshad commited on
Commit
549186a
·
1 Parent(s): 174fdf7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -16
app.py CHANGED
@@ -1,26 +1,15 @@
1
  import torch
2
  from transformers import XLNetTokenizer, XLNetForSequenceClassification
3
  import gradio as gr
4
- from pydrive.auth import GoogleAuth
5
- from pydrive.drive import GoogleDrive
6
 
7
- # Authenticate and create GoogleDrive instance
8
- gauth = GoogleAuth()
9
- gauth.LocalWebserverAuth()
10
- drive = GoogleDrive(gauth)
11
 
12
- # ID of the file in Google Drive
13
- file_id = '1-7O5gAFgcIzgJ68WkSSpmh1H6kJL6fAO' # Replace this with your file's ID from Google Drive
14
- destination_path = '/content/XLNet_model_project_Core.pt' # Path to save the downloaded model file
15
 
16
- # Download the model file from Google Drive
17
- downloaded_file = drive.CreateFile({'id': file_id})
18
- downloaded_file.GetContentFile(destination_path)
19
-
20
- # Load the saved model
21
  tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
22
- model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', num_labels=2)
23
- model.load_state_dict(torch.load(destination_path))
24
  model.eval()
25
 
26
  # Function for prediction
 
1
  import torch
2
  from transformers import XLNetTokenizer, XLNetForSequenceClassification
3
  import gradio as gr
 
 
4
 
5
+ # Link to the saved model on Hugging Face Spaces
6
+ model_link = 'https://huggingface.co/spaces/AliArshad/SeverityPrediction/blob/main/severitypredictor.pt'
 
 
7
 
8
+ # Load the model directly from the link
9
+ model = XLNetForSequenceClassification.from_pretrained(model_link)
 
10
 
11
+ # Other model setup
 
 
 
 
12
  tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
 
 
13
  model.eval()
14
 
15
  # Function for prediction