Stefaron committed · verified
Commit 52a7c13 · 1 Parent(s): 12f297e

Upload trashclassification_adamata.py

Files changed (1):
  trashclassification_adamata.py (+116, -0)
trashclassification_adamata.py ADDED
import requests
import zipfile
import io
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout

output_dir = "./data"
url = "https://huggingface.co/datasets/garythung/trashnet/resolve/main/dataset-resized.zip"

# Download and extract the dataset ZIP (skipped when it is already available locally)
extracted_dir = os.path.join(output_dir, "dataset-resized")
if not os.path.isdir(extracted_dir):
    response = requests.get(url)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(output_dir)

data_dir = './data/dataset-resized'
garbage_types = os.listdir(data_dir)

# Build a pandas DataFrame of (image file path, label) pairs, one row per image
data = []
for garbage_type in garbage_types:
    garbage_type_path = os.path.join(data_dir, garbage_type)
    if os.path.isdir(garbage_type_path):
        for file in os.listdir(garbage_type_path):
            data.append((os.path.join(garbage_type_path, file), garbage_type))

df = pd.DataFrame(data, columns=['filepath', 'label'])

# Stratified 80/20 train/validation split so class proportions are preserved
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Data augmentation for training; preprocess_input applies ResNet50's
# ImageNet preprocessing (channel reordering and mean subtraction)
train_datagen = ImageDataGenerator(
    rotation_range=60,
    width_shift_range=0.15,
    height_shift_range=0.15,
    zoom_range=0.20,
    horizontal_flip=True,
    vertical_flip=True,
    shear_range=0.05,
    brightness_range=[0.9, 1.1],
    channel_shift_range=10,
    fill_mode='nearest',
    preprocessing_function=preprocess_input
)

# Validation data gets only the ResNet50 preprocessing, no augmentation
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    x_col="filepath",
    y_col="label",
    target_size=(384, 384),
    batch_size=32,
    class_mode='categorical',
    shuffle=True  # shuffle training batches each epoch
)

val_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    x_col="filepath",
    y_col="label",
    target_size=(384, 384),
    batch_size=32,
    class_mode='categorical',
    shuffle=False  # keep a fixed order so predictions line up with val_df rows
)
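
# Quick shape check (a sketch): the first augmented batch should come out as
# images of shape (32, 384, 384, 3) with one-hot labels of shape (32, 6).
batch_images, batch_labels = next(train_generator)
print(batch_images.shape, batch_labels.shape)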

# Class weights to offset class imbalance. np.unique returns the labels in
# sorted order, which matches the alphabetical ordering flow_from_dataframe
# uses for class_indices, so zipping the two keeps labels and weights aligned.
class_labels = np.unique(train_df['label'])
print(class_labels)
print(train_generator.class_indices)

weights = compute_class_weight(class_weight='balanced', classes=class_labels, y=train_df['label'])
class_weights = dict(zip(train_generator.class_indices.values(), weights))
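
# Sanity check (a sketch): print each label next to its computed weight to
# confirm that index i in class_weights really refers to that label.
for label, idx in train_generator.class_indices.items():
    print(f"{label}: weight {class_weights[idx]:.3f}")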

# Model: ImageNet-pretrained ResNet50 backbone with a new classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(384, 384, 3))

# Freeze the first 143 layers (roughly everything before the final
# convolutional stage); only the top of the backbone is fine-tuned
for layer in base_model.layers[:143]:
    layer.trainable = False

# Head: global average pooling, dropout, and a 6-way softmax (one unit per TrashNet class)
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.5)(x)
x = Dense(6, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=x)
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
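
# Optional sanity check: print the architecture and confirm how many
# parameters remain trainable after freezing the backbone layers.
model.summary()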

# Callbacks: halve the LR when val_loss plateaus, stop early after 8 stalled
# epochs (restoring the best weights), and checkpoint the best model to disk
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=2, min_lr=0.00001)
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=8, restore_best_weights=True, verbose=1)
model_checkpoint = ModelCheckpoint(filepath="best_model.keras", monitor="val_loss", save_best_only=True, verbose=1)

callbacks = [reduce_lr, early_stopping, model_checkpoint]

# Model training
history = model.fit(
    train_generator,
    epochs=50,
    validation_data=val_generator,
    class_weight=class_weights,
    callbacks=callbacks
)
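
# Follow-up sketch: reload the best checkpoint written by ModelCheckpoint above
# and evaluate it on the validation set (assumes "best_model.keras" was saved
# at least once during training).
from tensorflow.keras.models import load_model

best_model = load_model("best_model.keras")
val_loss, val_acc = best_model.evaluate(val_generator, verbose=1)
print(f"best checkpoint - val_loss: {val_loss:.4f}, val_accuracy: {val_acc:.4f}")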