nataliegilbert committed on
Commit
ada82ad
·
verified ·
1 Parent(s): a73d244

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +88 -3
README.md CHANGED
@@ -8,10 +8,13 @@ tags: []
8
  <!-- Provide a quick summary of what the model is/does. -->
9
  This is the baseline model for the news source classification project.
10
 
11
- Code to load the model:
 
 
12
  from huggingface_hub import hf_hub_download
13
  import joblib
14
 
 
15
  repo_id='awngsz/baseline_model'
16
  filename='CIS5190_Proj2_AWNGSZ.joblib'
17
 
@@ -20,8 +23,90 @@ model=joblib.load(file_path)
20
 
21
  print(model)
22
 
23
- Code to perform inference:
24
- model.predict('test.csv')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  ## Model Details
27
 
 
8
  <!-- Provide a quick summary of what the model is/does. -->
9
  This is the baseline model for the news source classification project.
10
 
11
+ Please run the following evaluation pipeline code:
12
+
13
+ ##### START #####
14
  from huggingface_hub import hf_hub_download
15
  import joblib
16
 
17
+ #Load model from Huggingface
18
  repo_id='awngsz/baseline_model'
19
  filename='CIS5190_Proj2_AWNGSZ.joblib'
20
 
 
23
 
24
  print(model)
25
 
26
+ #Load test dataset (assuming the name is the same as the one in the Ed post)
27
+ test_df = pd.read_csv(file_path)
28
+
29
+ #Copying the naming convention from the sample dataset in the Ed post
30
+ X_test = test_df['title']
31
+ y_test = test_df['labels']
32
+
33
+ #Load the embedding model from Huggingface
34
+ ############################################# Transformer: DistilBERT #############################################
35
+ from transformers import DistilBertTokenizer, DistilBertModel
36
+ # pytorch related packages
37
+ import torch
38
+ import torchvision
39
+ from torchvision import transforms, utils
40
+ import torch.nn as nn
41
+ import torch.optim as optim
42
+ import torchvision.transforms as transforms
43
+ from PIL import Image
44
+ from skimage import io, transform
45
+ from torchvision.io import read_image
46
+ from torch.utils.data import Dataset, DataLoader
47
+
48
+ def get_embeddings(text_all, tokenizer, model, max_len = 128):
49
+ '''
50
+ return: embeddings list
51
+ '''
52
+ embeddings = []
53
+ count = 0
54
+ print('Start embeddings:')
55
+ for text in text_all:
56
+ count += 1
57
+ if count % (len(text_all) // 10) == 0:
58
+ print(f'{count / len(text_all) * 100:.1f}% done ...')
59
+
60
+ model_input_token = tokenizer(
61
+ text,
62
+ add_special_tokens = True,
63
+ max_length = max_len,
64
+ padding = 'max_length',
65
+ truncation = True,
66
+ return_tensors = 'pt'
67
+ )
68
+
69
+ with torch.no_grad():
70
+ model_output = model(**model_input_token)
71
+ cls_embedding = model_output.last_hidden_state[:, 0, :]
72
+ cls_embedding = cls_embedding.squeeze().numpy()
73
+ embeddings.append(cls_embedding)
74
+
75
+ return embeddings
76
+
77
+ #Load the tokenizer and model from Hugging Face
78
+ tokenizer_DBERT = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
79
+ transformer_model_DBERT = DistilBertModel.from_pretrained('distilbert-base-uncased')
80
+
81
+ #Set the model to evaluation mode
82
+ transformer_model_DBERT.eval()
83
+
84
+ #Get the embeddings for the test data
85
+
86
+ max_len = max(len(text) for text in X_test)
87
+
88
+ #This may take a while to run
89
+ X_test_embeddings_DBERT = get_embeddings(X_test, tokenizer_DBERT, transformer_model_DBERT, max_len = max_len)
90
+
91
+ prediction = model.predict(X_test_embeddings_DBERT)
92
+
93
+ #Accuracy
94
+ from sklearn.metrics import accuracy_score
95
+
96
+ label_map = {'NBC': 1, 'FoxNews': 0}
97
+
98
+ def compute_category_accuracy(y_true, y_pred, label):
99
+ n_correct = np.sum((y_true == label) & (y_pred == label))
100
+ n_total = np.sum(y_true == label)
101
+ cat_accuracy = n_correct / n_total
102
+ return cat_accuracy
103
+
104
+ #Print accuracy
105
+ print(f'Test accuracy: {accuracy_score(y_test, prediction) * 100:.2f}%')
106
+ print(f'Test accuracy for NBC: {compute_category_accuracy(y_test, prediction, label_map["NBC"]) * 100:.2f}%')
107
+ print(f'Test accuracy for FoxNews: {compute_category_accuracy(y_test, prediction, label_map["FoxNews"]) * 100:.2f}%')
108
+
109
+ ##### END ######
110
 
111
  ## Model Details
112