Sushrut98 commited on
Commit
d9a8fa1
·
verified ·
1 Parent(s): 757cc16

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +101 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+ from transformers import BertTokenizer, BertModel, BertConfig
6
+ # from model import create_effnetb2_model
7
+ from timeit import default_timer as timer
8
+ # from typing import Tuple, Dict
9
+
10
+ # Setup class names
11
+ # class_names = ["pizza", "steak", "sushi"]
12
+
13
+ ### 2. Model and transforms preparation ###
14
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
15
+ do_lower_case=True)
16
+ # Create BERT model
17
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased",
18
+ num_labels=len(label_dict),
19
+ output_attentions=False,
20
+ output_hidden_states=False)
21
+ model.load_state_dict(torch.load('/content/finetuned_BERT_epoch_10.model', map_location=torch.device('cpu')))
22
+ ### 3. Predict function ###
23
+
24
+ # Create predict function
25
+ def predict(text) :
26
+ """Transforms and performs a prediction on Text.
27
+ """
28
+ # Start the timer
29
+ start_time = timer()
30
+ encoding = tokenizer.encode_plus(
31
+ text,
32
+ None,
33
+ add_special_tokens=True,
34
+ max_length=256,
35
+ pad_to_max_length=True,
36
+ return_token_type_ids=True,
37
+ return_tensors='pt'
38
+ )
39
+
40
+ model.eval()
41
+
42
+ loss_val_total = 0
43
+ predictions = []
44
+ # batch = tuple(prediction)
45
+
46
+ inputs = {'input_ids': encoding["input_ids"],
47
+ 'attention_mask': encoding["attention_mask"],
48
+ }
49
+
50
+ with torch.no_grad():
51
+ outputs = model(**inputs)
52
+
53
+ print(outputs)
54
+ # loss = outputs[0]
55
+ logits = outputs[0]
56
+ # loss_val_total += loss.item()
57
+
58
+ logits = logits.detach().cpu().numpy()
59
+ # print(logits)
60
+ # label_ids = inputs['labels'].cpu().numpy()
61
+ predictions.append(logits)
62
+ # true_vals.append(label_ids)
63
+
64
+ # loss_val_avg = loss_val_total/len(dataloader_val)
65
+
66
+ predictions = np.concatenate(predictions, axis=0)
67
+
68
+ preds_flat = np.argmax(predictions, axis=1).flatten()
69
+
70
+ if preds_flat==0:
71
+ prediction = "positive"
72
+ else:
73
+ prediction = "negative"
74
+
75
+ # Calculate the prediction time
76
+ pred_time = round(timer() - start_time, 5)
77
+
78
+ # Return the prediction dictionary and prediction time
79
+ return prediction, pred_time
80
+
81
+ ### 4. Gradio app ###
82
+
83
+ # Create title, description and article strings
84
+ title = "Sentiment Analysis"
85
+ description = "An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
86
+
87
+ # Create examples list from "examples/" directory
88
+ # example_list = [["examples/" + example] for example in os.listdir("examples")]
89
+
90
+ # Create the Gradio demo
91
+ demo = gr.Interface(fn=predict, # mapping function from input to output
92
+ inputs=["text", "checkbox"],
93
+ outputs=["text",
94
+ gr.Number(label="Prediction time (s)")],
95
+ # Create examples list from "examples/" directory
96
+ # examples=example_list,
97
+ title=title,
98
+ description=description)
99
+
100
+ # Launch the demo!
101
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ transformers==4.35.2
4
+ gradio==3.1.4