ncoria committed on
Commit
2aacf30
·
verified ·
1 Parent(s): 389f3f5

edit embeddings_df key for train

Browse files

in order to prevent issues between pages

Files changed (1) hide show
  1. train_model.py +159 -159
train_model.py CHANGED
@@ -1,159 +1,159 @@
1
- import os
2
- import io
3
- import pickle
4
- import regex
5
- import streamlit as st
6
- import plotly.express as px
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- from utils.seqIo import seqIo_reader
11
- import pandas as pd
12
- from PIL import Image
13
- from pathlib import Path
14
- from transformers import AutoProcessor, AutoModel
15
- from tqdm import tqdm
16
- from sklearn.svm import SVC
17
- from sklearn.model_selection import train_test_split
18
- from sklearn.metrics import accuracy_score, classification_report
19
- from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
20
-
21
- # --server.maxUploadSize 3000
22
-
23
def get_unique_labels(label_list: list[str]):
    """Collect every distinct individual label from '||'-joined label strings."""
    # A set comprehension flattens the '||'-separated entries and deduplicates.
    seen = {part for entry in label_list for part in entry.split('||')}
    return list(seen)
30
-
31
@st.cache_data
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """Cached wrapper around sklearn's train_test_split for embeddings and labels."""
    split = train_test_split(
        train_embeds,
        numerical_labels,
        test_size=test_size,
        random_state=random_state,
    )
    return split
34
-
35
@st.cache_resource
def train_model(X_train, y_train, random_state=42):
    """Fit and return an RBF-kernel SVM classifier with probability estimates enabled."""
    classifier = SVC(
        kernel='rbf',
        random_state=random_state,
        probability=True,
        verbose=True,
    )
    classifier.fit(X_train, y_train)
    return classifier
41
-
42
def pickle_model(model):
    """Serialize *model* with pickle into an in-memory buffer.

    Returns an io.BytesIO rewound to position 0 so consumers that read the
    stream (e.g. st.download_button) receive the full payload rather than
    hitting EOF immediately.
    """
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    pickled.seek(0)  # original left the cursor at EOF; a subsequent read() returned b''
    return pickled
46
-
47
# --- Session-state bootstrap ---
# Initialise keys on first page load so later reads never raise KeyError.
if "embeddings_df" not in st.session_state:
    st.session_state.embeddings_df = None

if "svm_clf" not in st.session_state:
    # The trained classifier and its evaluation artefacts are initialised
    # together so they stay consistent as a unit across reruns.
    st.session_state.svm_clf = None
    st.session_state.report_df = None
    st.session_state.accuracy = None

st.title('batik: frame classifier training')
56
-
57
st.text("Upload files to train classifier on.")
# Inputs are grouped in a form so the expensive embedding generation only
# runs when the submit button is pressed, not on every widget interaction.
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate',value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

# Alternative path: skip generation entirely by uploading precomputed embeddings.
st.markdown("**(Optional)** Upload embeddings.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
66
-
67
# Path 1: generate embeddings from the uploaded video + annotations.
if submit_embed_settings and seq_file is not None and annot_files is not None:
    # Stream SLIP embeddings from the uploaded file object (no temp files).
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)

    fnames = [seq_file.name]
    # NOTE(review): annot_files is already a list (accept_multiple_files=True);
    # wrapping it in another list looks suspicious — confirm
    # create_embeddings_csv_io really expects a list of lists here.
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df = embeddings_df

# Path 2: load precomputed embeddings from the optional CSV upload.
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
92
-
93
# Training UI — only shown once embeddings are available in session state.
if st.session_state.embeddings_df is not None:
    st.subheader("specify dataset preprocessing options")
    st.text("Select frames with label(s) to include:")

    with st.form('train_settings'):
        # Build the multiselect option lists from the labels present in the data.
        label_list = st.session_state.embeddings_df['Label'].to_list()
        unique_label_list = get_unique_labels(label_list)
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)

        st.text("Select label(s) that should be removed:")
        classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)

        max_class_size = st.number_input("(Optional) Specify max class size:", value=None)

        shuffle_data = st.toggle("Shuffle data:")

        train_model_clicked = st.form_submit_button("Train Model")

    if train_model_clicked:
        # Filter/balance the dataset according to the form selections.
        kwargs = {'embeddings_df' : st.session_state.embeddings_df,
                  'specified_classes' : specified_classes,
                  'classes_to_remove' : classes_to_remove,
                  'max_class_size' : max_class_size,
                  'animal_state' : None,
                  'view' : None,
                  'shuffle_data' : shuffle_data,
                  'test_videos' : None}
        # train_images is unused below; only embeddings and labels are needed.
        train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
        # Convert labels to numerical values
        # 'other' is pinned to index 0 so the background class is always first.
        label_to_appear_first = 'other'
        unique_labels = set(train_labels)
        unique_labels.discard(label_to_appear_first)

        label_to_index = {label_to_appear_first: 0}

        label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
        index_to_label = {idx: label for label, idx in label_to_index.items()}
        numerical_labels = np.array([label_to_index[label] for label in train_labels])

        print("Label Valence: ", label_to_index)
        # Split data into train and test sets
        X_train, X_test, y_train, y_test = get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42)
        with st.spinner("Model training in progress..."):
            svm_clf = train_model(X_train, y_train)

        # Predict on the test set
        with st.spinner("In progress..."):
            y_pred = svm_clf.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            # NOTE(review): target_names assumes every index 0..len-1 maps to a
            # class present in the data; classification_report may reject a
            # mismatched target_names length if a class is absent — confirm.
            report = classification_report(y_test, y_pred, target_names=[index_to_label[idx] for idx in range(len(label_to_index))], output_dict=True)
            report_df = pd.DataFrame(report).transpose()

        # save results to session state
        # (so they survive Streamlit reruns triggered by widget interaction)
        st.session_state.svm_clf = svm_clf
        st.session_state.report_df = report_df
        st.session_state.accuracy = accuracy

    # Results/download pane — rendered whenever a trained model exists.
    if st.session_state.svm_clf is not None:
        pickled_model = pickle_model(st.session_state.svm_clf)

        st.text(f"Eval Accuracy: {st.session_state.accuracy}")
        st.subheader("Classification Report:")
        st.dataframe(st.session_state.report_df)

        st.download_button("Download model as .pkl file",
                           data=pickled_model,
                           file_name=f"{'_'.join(specified_classes)}_classifier.pkl")
 
1
+ import os
2
+ import io
3
+ import pickle
4
+ import regex
5
+ import streamlit as st
6
+ import plotly.express as px
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from utils.seqIo import seqIo_reader
11
+ import pandas as pd
12
+ from PIL import Image
13
+ from pathlib import Path
14
+ from transformers import AutoProcessor, AutoModel
15
+ from tqdm import tqdm
16
+ from sklearn.svm import SVC
17
+ from sklearn.model_selection import train_test_split
18
+ from sklearn.metrics import accuracy_score, classification_report
19
+ from utils.utils import create_embeddings_csv_io, process_dataset_in_mem, generate_embeddings_stream_io
20
+
21
+ # --server.maxUploadSize 3000
22
+
23
def get_unique_labels(label_list: list[str]) -> list[str]:
    """Return the unique individual labels found in *label_list*.

    Each entry may contain several labels joined by '||' (e.g. 'walk||groom').
    The result is sorted so that the multiselect option order in the UI is
    stable across reruns (plain ``list(set)`` order varies with string hash
    randomization between interpreter runs).
    """
    # Set comprehension flattens the '||'-joined entries and deduplicates.
    unique = {part for label in label_list for part in label.split('||')}
    return sorted(unique)
30
+
31
@st.cache_data
def get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42):
    """Cached wrapper around sklearn's train_test_split for embeddings and labels."""
    split = train_test_split(
        train_embeds,
        numerical_labels,
        test_size=test_size,
        random_state=random_state,
    )
    return split
34
+
35
@st.cache_resource
def train_model(X_train, y_train, random_state=42):
    """Fit and return an RBF-kernel SVM classifier with probability estimates enabled."""
    classifier = SVC(
        kernel='rbf',
        random_state=random_state,
        probability=True,
        verbose=True,
    )
    classifier.fit(X_train, y_train)
    return classifier
41
+
42
def pickle_model(model):
    """Serialize *model* with pickle into an in-memory buffer.

    Returns an io.BytesIO rewound to position 0 so consumers that read the
    stream (e.g. st.download_button) receive the full payload rather than
    hitting EOF immediately.
    """
    pickled = io.BytesIO()
    pickle.dump(model, pickled)
    pickled.seek(0)  # original left the cursor at EOF; a subsequent read() returned b''
    return pickled
46
+
47
# --- Session-state bootstrap ---
# Key is suffixed "_train" (per the commit message: avoids clashing with the
# same key used on other pages of the app). Initialise on first load so later
# reads never raise KeyError.
if "embeddings_df_train" not in st.session_state:
    st.session_state.embeddings_df_train = None

if "svm_clf" not in st.session_state:
    # The trained classifier and its evaluation artefacts are initialised
    # together so they stay consistent as a unit across reruns.
    st.session_state.svm_clf = None
    st.session_state.report_df = None
    st.session_state.accuracy = None

st.title('batik: frame classifier training')
56
+
57
st.text("Upload files to train classifier on.")
# Inputs are grouped in a form so the expensive embedding generation only
# runs when the submit button is pressed, not on every widget interaction.
with st.form('embedding_generation_settings'):
    seq_file = st.file_uploader("Choose a video file", type=['seq', 'mp4'])
    annot_files = st.file_uploader("Choose an annotation File", type=['annot','csv'], accept_multiple_files=True)
    downsample_rate = st.number_input('Downsample Rate',value=4)
    submit_embed_settings = st.form_submit_button('Create Embeddings', type='secondary')

# Alternative path: skip generation entirely by uploading precomputed embeddings.
st.markdown("**(Optional)** Upload embeddings.")
embeddings_csv = st.file_uploader("Choose a .csv File", type=['csv'])
66
+
67
# Path 1: generate embeddings from the uploaded video + annotations.
if submit_embed_settings and seq_file is not None and annot_files is not None:
    # Stream SLIP embeddings from the uploaded file object (no temp files).
    video_embeddings, video_frames = generate_embeddings_stream_io([seq_file],
                                                                   "SLIP",
                                                                   downsample_rate,
                                                                   False)

    fnames = [seq_file.name]
    # NOTE(review): annot_files is already a list (accept_multiple_files=True);
    # wrapping it in another list looks suspicious — confirm
    # create_embeddings_csv_io really expects a list of lists here.
    embeddings_df = create_embeddings_csv_io(out="file",
                                             fnames=fnames,
                                             embeddings=video_embeddings,
                                             frames=video_frames,
                                             annotations=[annot_files],
                                             test_fnames=None,
                                             views=None,
                                             conditions=None,
                                             downsample_rate=downsample_rate)
    st.session_state.embeddings_df_train = embeddings_df

# Path 2: load precomputed embeddings from the optional CSV upload.
elif embeddings_csv is not None:
    embeddings_df = pd.read_csv(embeddings_csv)
    st.session_state.embeddings_df_train = embeddings_df
else:
    st.text('Please upload file(s).')

st.divider()
92
+
93
# Training UI — only shown once embeddings are available in session state.
if st.session_state.embeddings_df_train is not None:
    st.subheader("specify dataset preprocessing options")
    st.text("Select frames with label(s) to include:")

    with st.form('train_settings'):
        # Build the multiselect option lists from the labels present in the data.
        label_list = st.session_state.embeddings_df_train['Label'].to_list()
        unique_label_list = get_unique_labels(label_list)
        specified_classes = st.multiselect("Label(s) included:", options=unique_label_list)

        st.text("Select label(s) that should be removed:")
        classes_to_remove = st.multiselect("Label(s) excluded:", options=unique_label_list)

        max_class_size = st.number_input("(Optional) Specify max class size:", value=None)

        shuffle_data = st.toggle("Shuffle data:")

        train_model_clicked = st.form_submit_button("Train Model")

    if train_model_clicked:
        # Filter/balance the dataset according to the form selections.
        kwargs = {'embeddings_df' : st.session_state.embeddings_df_train,
                  'specified_classes' : specified_classes,
                  'classes_to_remove' : classes_to_remove,
                  'max_class_size' : max_class_size,
                  'animal_state' : None,
                  'view' : None,
                  'shuffle_data' : shuffle_data,
                  'test_videos' : None}
        # train_images is unused below; only embeddings and labels are needed.
        train_embeds, train_labels, train_images, _, _, _ = process_dataset_in_mem(**kwargs)
        # Convert labels to numerical values
        # 'other' is pinned to index 0 so the background class is always first.
        label_to_appear_first = 'other'
        unique_labels = set(train_labels)
        unique_labels.discard(label_to_appear_first)

        label_to_index = {label_to_appear_first: 0}

        label_to_index.update({label: idx + 1 for idx, label in enumerate(unique_labels)})
        index_to_label = {idx: label for label, idx in label_to_index.items()}
        numerical_labels = np.array([label_to_index[label] for label in train_labels])

        print("Label Valence: ", label_to_index)
        # Split data into train and test sets
        X_train, X_test, y_train, y_test = get_train_test_split(train_embeds, numerical_labels, test_size=0.05, random_state=42)
        with st.spinner("Model training in progress..."):
            svm_clf = train_model(X_train, y_train)

        # Predict on the test set
        with st.spinner("In progress..."):
            y_pred = svm_clf.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            # NOTE(review): target_names assumes every index 0..len-1 maps to a
            # class present in the data; classification_report may reject a
            # mismatched target_names length if a class is absent — confirm.
            report = classification_report(y_test, y_pred, target_names=[index_to_label[idx] for idx in range(len(label_to_index))], output_dict=True)
            report_df = pd.DataFrame(report).transpose()

        # save results to session state
        # (so they survive Streamlit reruns triggered by widget interaction)
        st.session_state.svm_clf = svm_clf
        st.session_state.report_df = report_df
        st.session_state.accuracy = accuracy

    # Results/download pane — rendered whenever a trained model exists.
    if st.session_state.svm_clf is not None:
        pickled_model = pickle_model(st.session_state.svm_clf)

        st.text(f"Eval Accuracy: {st.session_state.accuracy}")
        st.subheader("Classification Report:")
        st.dataframe(st.session_state.report_df)

        st.download_button("Download model as .pkl file",
                           data=pickled_model,
                           file_name=f"{'_'.join(specified_classes)}_classifier.pkl")