peterkros commited on
Commit
f9b0725
·
1 Parent(s): ae85134

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -3
app.py CHANGED
@@ -3,6 +3,98 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
  import torch
4
  import pickle
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Model names for level1 and level2
8
  model_name_level1 = "peterkros/COFOG-bert2"
@@ -37,14 +129,29 @@ def predict(text):
37
  predicted_class_level1 = torch.argmax(probs_level1, dim=-1).item()
38
  predicted_label_level1 = label_encoder_level1.inverse_transform([predicted_class_level1])[0]
39
 
40
- # Predict Level2 (assuming level2 model uses both text and predicted level1 label)
41
  combined_input = text + " " + predicted_label_level1
42
  inputs_level2 = tokenizer_level2(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
43
  with torch.no_grad():
44
  outputs_level2 = model_level2(**inputs_level2)
45
  probs_level2 = torch.nn.functional.softmax(outputs_level2.logits, dim=-1)
46
- predicted_class_level2 = torch.argmax(probs_level2, dim=-1).item()
47
- predicted_label_level2 = label_encoder_level2.inverse_transform([predicted_class_level2])[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  combined_prediction = f"Level1: {predicted_label_level1} - Level2: {predicted_label_level2}"
49
  return combined_prediction
50
 
 
3
  import torch
4
  import pickle
5
 
6
+ level1_to_level2_mapping = {
7
+ "General public services": [
8
+ "Executive and legislative organs, financial and fiscal affairs, external affairs",
9
+ "Foreign economic aid",
10
+ "General services",
11
+ "Basic research",
12
+ "R&D General public services",
13
+ "General public services n.e.c.",
14
+ "Public debt transactions",
15
+ "Transfers of a general character between different levels of government"
16
+ ],
17
+ "Defence": [
18
+ "Military defence",
19
+ "Civil defence",
20
+ "Foreign military aid",
21
+ "R&D Defence",
22
+ "Defence n.e.c."
23
+ ],
24
+ "Public order and safety": [
25
+ "Police services",
26
+ "Fire-protection services",
27
+ "Law courts",
28
+ "Prisons",
29
+ "R&D Public order and safety",
30
+ "Public order and safety n.e.c."
31
+ ],
32
+ "Economic affairs": [
33
+ "General economic, commercial and labour affairs",
34
+ "Agriculture, forestry, fishing and hunting",
35
+ "Fuel and energy",
36
+ "Mining, manufacturing and construction",
37
+ "Transport",
38
+ "Communication",
39
+ "Other industries",
40
+ "R&D Economic affairs",
41
+ "Economic affairs n.e.c."
42
+ ],
43
+ "Environmental protection": [
44
+ "Waste management",
45
+ "Waste water management",
46
+ "Pollution abatement",
47
+ "Protection of biodiversity and landscape",
48
+ "R&D Environmental protection",
49
+ "Environmental protection n.e.c."
50
+ ],
51
+ "Housing and community amenities": [
52
+ "Housing development",
53
+ "Community development",
54
+ "Water supply",
55
+ "Street lighting",
56
+ "R&D Housing and community amenities",
57
+ "Housing and community amenities n.e.c."
58
+ ],
59
+ "Health": [
60
+ "Medical products, appliances and equipment",
61
+ "Outpatient services",
62
+ "Hospital services",
63
+ "Public health services",
64
+ "R&D Health",
65
+ "Health n.e.c."
66
+ ],
67
+ "Recreation, culture and religion": [
68
+ "Recreational and sporting services",
69
+ "Cultural services",
70
+ "Broadcasting and publishing services",
71
+ "Religious and other community services",
72
+ "R&D Recreation, culture and religion",
73
+ "Recreation, culture and religion n.e.c."
74
+ ],
75
+ "Education": [
76
+ "Pre-primary and primary education",
77
+ "Secondary education",
78
+ "Post-secondary non-tertiary education",
79
+ "Tertiary education",
80
+ "Education not definable by level",
81
+ "Subsidiary services to education",
82
+ "R&D Education",
83
+ "Education n.e.c."
84
+ ],
85
+ "Social protection": [
86
+ "Sickness and disability",
87
+ "Old age",
88
+ "Survivors",
89
+ "Family and children",
90
+ "Unemployment",
91
+ "Housing",
92
+ "Social exclusion n.e.c.",
93
+ "R&D Social protection",
94
+ "Social protection n.e.c."
95
+ ]
96
+ }
97
+
98
 
99
  # Model names for level1 and level2
100
  model_name_level1 = "peterkros/COFOG-bert2"
 
129
  predicted_class_level1 = torch.argmax(probs_level1, dim=-1).item()
130
  predicted_label_level1 = label_encoder_level1.inverse_transform([predicted_class_level1])[0]
131
 
132
+ # Predict Level2 (assuming level2 model uses both text and predicted level1 label)
133
  combined_input = text + " " + predicted_label_level1
134
  inputs_level2 = tokenizer_level2(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
135
+
136
  with torch.no_grad():
137
  outputs_level2 = model_level2(**inputs_level2)
138
  probs_level2 = torch.nn.functional.softmax(outputs_level2.logits, dim=-1)
139
+
140
+ # Extract the probabilities for the candidate level2 categories
141
+ level2_candidates = level1_to_level2_mapping.get(predicted_label_level1, [])
142
+ candidate_indices = [label_encoder_level2.transform([candidate])[0] for candidate in level2_candidates if candidate in label_encoder_level2.classes_]
143
+
144
+ # Filter the probabilities
145
+ filtered_probs = probs_level2[0, candidate_indices]
146
+
147
+ # Get the highest probability label from the filtered list
148
+ if len(filtered_probs) > 0:
149
+ highest_prob_index = torch.argmax(filtered_probs).item()
150
+ predicted_class_level2 = candidate_indices[highest_prob_index]
151
+ predicted_label_level2 = label_encoder_level2.inverse_transform([predicted_class_level2])[0]
152
+ else:
153
+ predicted_label_level2 = "n.e.c"
154
+
155
  combined_prediction = f"Level1: {predicted_label_level1} - Level2: {predicted_label_level2}"
156
  return combined_prediction
157