ManjinderUNCC committed
Commit d2e6a80
1 Parent(s): c2eb30b

Update python_Code/evaluate_model.py

Files changed (1)
  1. python_Code/evaluate_model.py +46 -55
python_Code/evaluate_model.py CHANGED
@@ -1,55 +1,46 @@
- import gradio as gr
-
- # Function to execute evaluate_model.py
- def evaluate_model_script():
-     import spacy
-     import jsonlines
-     from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
-
-     # Load the trained spaCy model
-     nlp = spacy.load("./my_trained_model")
-
-     # Load the golden evaluation data
-     golden_eval_data = []
-     with jsonlines.open("data/goldenEval.jsonl") as reader:
-         for record in reader:
-             golden_eval_data.append(record)
-
-     # Predict labels for each record using your model
-     predicted_labels = []
-     for record in golden_eval_data:
-         text = record["text"]
-         doc = nlp(text)
-         predicted_labels.append(doc.cats)
-
-     # Extract ground truth labels from the golden evaluation data
-     true_labels = [record["accept"] for record in golden_eval_data]
-
-     # Convert label format to match sklearn's classification report format
-     true_labels_flat = [label[0] if label else "reject" for label in true_labels]
-     predicted_labels_flat = [max(pred, key=pred.get) for pred in predicted_labels]
-
-     # Calculate evaluation metrics
-     accuracy = accuracy_score(true_labels_flat, predicted_labels_flat)
-     precision = precision_score(true_labels_flat, predicted_labels_flat, average='weighted')
-     recall = recall_score(true_labels_flat, predicted_labels_flat, average='weighted')
-     f1 = f1_score(true_labels_flat, predicted_labels_flat, average='weighted')
-
-     # Additional classification report
-     report = classification_report(true_labels_flat, predicted_labels_flat)
-
-     # Build the result dictionary
-     result = {
-         "accuracy": accuracy,
-         "precision": precision,
-         "recall": recall,
-         "f1_score": f1,
-         "detailed_classification_report": report
-     }
-
-     return result
-
- # Gradio Interface
- output = gr.outputs.Label(type="json", label="Evaluation Metrics")
- iface = gr.Interface(fn=evaluate_model_script, outputs=output, title="Evaluate Model Script")
- iface.launch()
 
+ import spacy
+ import jsonlines
+ from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
+
+ # Load the trained spaCy model
+ nlp = spacy.load("./my_trained_model")
+
+ # Load the golden evaluation data
+ golden_eval_data = []
+ with jsonlines.open("data/goldenEval.jsonl") as reader:
+     for record in reader:
+         golden_eval_data.append(record)
+
+ # Predict labels for each record using your model
+ predicted_labels = []
+ for record in golden_eval_data:
+     text = record["text"]
+     doc = nlp(text)
+     predicted_labels.append(doc.cats)
+
+ # Extract ground truth labels from the golden evaluation data
+ true_labels = [record["accept"] for record in golden_eval_data]
+
+ # Convert label format to match sklearn's classification report format
+ true_labels_flat = [label[0] if label else "reject" for label in true_labels]
+ predicted_labels_flat = [max(pred, key=pred.get) for pred in predicted_labels]
+
+ # Calculate evaluation metrics
+ accuracy = accuracy_score(true_labels_flat, predicted_labels_flat)
+ precision = precision_score(true_labels_flat, predicted_labels_flat, average='weighted')
+ recall = recall_score(true_labels_flat, predicted_labels_flat, average='weighted')
+ f1 = f1_score(true_labels_flat, predicted_labels_flat, average='weighted')
+
+ # Additional classification report
+ report = classification_report(true_labels_flat, predicted_labels_flat)
+
+ # Print or save the evaluation metrics
+ print("Evaluation Metrics:")
+ print(f"Accuracy: {accuracy}")
+ print(f"Precision: {precision}")
+ print(f"Recall: {recall}")
+ print(f"F1-Score: {f1}")
+
+ # Print or save the detailed classification report
+ print("Detailed Classification Report:")
+ print(report)
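
For context: the updated script expects data/goldenEval.jsonl to hold one JSON record per line with a "text" field and an "accept" list of gold labels, a Prodigy-style layout inferred from the keys the script reads. The minimal sketch below illustrates that format and the label-flattening step; the label names "accept"/"reject" in the sample data are illustrative assumptions, except for the "reject" fallback, which comes from the script itself.

import jsonlines

# Hypothetical records in the layout evaluate_model.py reads; the
# "text"/"accept" keys match the script, the label values are assumed.
sample_records = [
    {"text": "Example document the annotator accepted.", "accept": ["accept"]},
    {"text": "Example document with no accepted label.", "accept": []},
]

with jsonlines.open("data/goldenEval.jsonl", mode="w") as writer:
    for record in sample_records:
        writer.write(record)

# The flattening step: an empty "accept" list falls back to "reject",
# and the predicted label is the highest-scoring category in doc.cats.
pred = {"accept": 0.91, "reject": 0.09}  # example doc.cats output
print(max(pred, key=pred.get))           # -> accept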