alinikkhah commited on
Commit
46eeeb0
·
verified ·
1 Parent(s): 52dcfd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -1,12 +1,22 @@
1
  import torch
2
  from transformers import BertForSequenceClassification
 
 
 
 
 
3
 
4
  # Load your BERT model
 
5
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
6
- model.load_state_dict(torch.load('bert_model_complete.pth', map_location=torch.device('cpu'), weights_only=True), strict=False)
 
 
 
 
 
7
  model.eval() # Set the model to evaluation mode
8
- import gradio as gr
9
- from transformers import BertTokenizer
10
 
11
  # Load the tokenizer
12
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -21,9 +31,6 @@ def predict(text):
21
 
22
  # Set up the Gradio interface
23
  interface = gr.Interface(fn=predict, inputs="text", outputs="label", title="BERT Text Classification")
24
- import torch
25
- from transformers import BertForSequenceClassification, BertTokenizer
26
- import gradio as gr
27
 
28
  # Load model and tokenizer
29
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
 
1
  import torch
2
  from transformers import BertForSequenceClassification
3
+ import gradio as gr
4
+ from transformers import BertTokenizer
5
+ import torch
6
+ from transformers import BertForSequenceClassification, BertTokenizer
7
+ import gradio as gr
8
 
9
  # Load your BERT model
10
+ # Load the model architecture
11
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
12
+
13
+ # Load the state dict without weights_only
14
+ try:
15
+ model.load_state_dict(torch.load('bert_model_complete.pth', map_location=torch.device('cpu')), strict=False)
16
+ except Exception as e:
17
+ print(f"Error loading state dict: {e}")
18
  model.eval() # Set the model to evaluation mode
19
+
 
20
 
21
  # Load the tokenizer
22
  tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
 
31
 
32
  # Set up the Gradio interface
33
  interface = gr.Interface(fn=predict, inputs="text", outputs="label", title="BERT Text Classification")
 
 
 
34
 
35
  # Load model and tokenizer
36
  model = BertForSequenceClassification.from_pretrained('bert-base-uncased')