Lambang commited on
Commit
2c2434b
·
1 Parent(s): 364ca9d
Files changed (1) hide show
  1. main.py +224 -2
main.py CHANGED
@@ -2,9 +2,57 @@ from fastapi import FastAPI
2
  import pickle
3
  import uvicorn
4
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  app = FastAPI()
7
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  #Endpoints
9
  #Root endpoints
10
  @app.get("/")
@@ -13,5 +61,179 @@ async def root():
13
  ngrok_url = "Tidak Ada URL Publik (ngrok belum selesai memulai)"
14
 
15
  return {"message": "Hello, World!", "ngrok_url": ngrok_url}
16
-
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import pickle
3
  import uvicorn
4
  import pandas as pd
5
+ import shutil
6
+ import cv2
7
+ import mediapipe as mp
8
+ from werkzeug.utils import secure_filename
9
+ import tensorflow as tf
10
+ import os
11
+ from flask import Flask, jsonify, request, flash, redirect, url_for
12
+ from pyngrok import ngrok
13
+ from fastapi import FastAPI, HTTPException, File, UploadFile, Request
14
+ from fastapi.staticfiles import StaticFiles
15
+ from fastapi.responses import JSONResponse
16
+ from pydantic import BaseModel
17
+ import subprocess
18
+
19
+ from file_processing import FileProcess
20
+ from get_load_data import GetLoadData
21
+ from data_preprocess import DataProcessing
22
+ from train_pred import TrainPred
23
 
24
  app = FastAPI()
25
+ public_url = "https://lambang0902-test-space.hf.space"
26
+ app.mount("/static", StaticFiles(directory="static"), name="static")
27
+ # Tempat deklarasi variabel-variabel penting
28
+ filepath = ""
29
+ list_class = ['Diamond','Oblong','Oval','Round','Square','Triangle']
30
+ list_folder = ['Training', 'Testing']
31
+ face_crop_img = True
32
+ face_landmark_img = True
33
+ landmark_extraction_img = True
34
+ # #-----------------------------------------------------
35
+ #
36
+ #
37
+ # #-----------------------------------------------------
38
+ # Tempat deklarasi model dan sejenisnya
39
+ selected_model = tf.keras.models.load_model(f'models/fc_model_1.h5', compile=False)
40
+ face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_alt2.xml')
41
+ mp_drawing = mp.solutions.drawing_utils
42
+ mp_face_mesh = mp.solutions.face_mesh
43
+ drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1)
44
+ # #-----------------------------------------------------
45
+ #
46
+ #
47
+ # #-----------------------------------------------------
48
+ # Tempat setting server
49
+ UPLOAD_FOLDER = './upload'
50
+ UPLOAD_MODEL = './models'
51
+ ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg','zip','h5'}
52
+ # app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
53
+ # app.config['UPLOAD_MODEL'] = UPLOAD_MODEL
54
+ # app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500 MB
55
+ # #-----------------------------------------------------
56
  #Endpoints
57
  #Root endpoints
58
  @app.get("/")
 
61
  ngrok_url = "Tidak Ada URL Publik (ngrok belum selesai memulai)"
62
 
63
  return {"message": "Hello, World!", "ngrok_url": ngrok_url}
 
64
 
65
+
66
+ # #-----------------------------------------------------
67
+ #
68
+ data_processor = DataProcessing()
69
+ data_train_pred = TrainPred()
70
+ #
71
+ import random
72
+ def preprocessing(filepath):
73
+ folder_path = './static/temporary'
74
+
75
+ shutil.rmtree(folder_path)
76
+ os.mkdir(folder_path)
77
+
78
+ # data_processor.detect_landmark(data_processor.face_cropping_pred(filepath))
79
+ data_processor.enhance_contrast_histeq(data_processor.face_cropping_pred(filepath))
80
+
81
+ files = os.listdir(folder_path)
82
+ index = 0
83
+ for file_name in files:
84
+ file_ext = os.path.splitext(file_name)[1]
85
+ new_file_name = str(index) + "_" + str(random.randint(1, 100000)) + file_ext
86
+ os.rename(os.path.join(folder_path, file_name), os.path.join(folder_path, new_file_name))
87
+ index += 1
88
+
89
+ print("Tungu sampai selesaiii")
90
+
91
+ train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255.)
92
+ test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255.)
93
+
94
+ ## -------------------------------------------------------------------------
95
+ ## API UNTUK MELAKUKAN PROSES PREDIKSI
96
+ ## -------------------------------------------------------------------------
97
+
98
+ @app.post('/upload/file',tags=["Predicting"])
99
+ async def upload_file(picture: UploadFile):
100
+ file_extension = picture.filename.split('.')[-1].lower()
101
+
102
+
103
+ if file_extension not in ALLOWED_EXTENSIONS:
104
+ raise HTTPException(status_code=400, detail='Invalid file extension')
105
+
106
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
107
+ file_path = os.path.join(UPLOAD_FOLDER, secure_filename(picture.filename))
108
+ with open(file_path, 'wb') as f:
109
+ f.write(picture.file.read())
110
+ try:
111
+ processed_img = preprocessing(cv2.imread(file_path))
112
+ except Exception as e:
113
+ os.remove(file_path)
114
+ raise HTTPException(status_code=500, detail=f'Error processing image: {str(e)}')
115
+
116
+ return JSONResponse(content={'message': 'File successfully uploaded'}, status_code=200)
117
+
118
+ @app.get('/get_images', tags=["Predicting"])
119
+ def get_images():
120
+ folder_path = "./static/temporary"
121
+ files = [f for f in os.listdir(folder_path) if os.path.isfile(os.path.join(folder_path, f))]
122
+ urls = []
123
+ for i in range(0, 3):
124
+ url = f'{public_url}/static/temporary/{files[i]}'
125
+ urls.append(url)
126
+ bentuk, persentase = data_train_pred.prediction(selected_model)
127
+ return {'urls': urls, 'bentuk_wajah':bentuk[0], 'persen':persentase}
128
+
129
+
130
+ ## -------------------------------------------------------------------------
131
+ ## API UNTUK MELAKUKAN PROSES TRAINING
132
+ ## -------------------------------------------------------------------------
133
+
134
+ # Model pydantic untuk validasi body
135
+ class TrainingParams(BaseModel):
136
+ optimizer: str
137
+ epoch: int
138
+ batchSize: int
139
+
140
+ @app.post('/upload/dataset', tags=["Training"])
141
+ async def upload_data(dataset: UploadFile):
142
+ if dataset.filename == '':
143
+ raise HTTPException(status_code=400, detail='No file selected for uploading')
144
+
145
+ # Buat path lengkap untuk menyimpan file
146
+ file_path = os.path.join(UPLOAD_FOLDER, dataset.filename)
147
+
148
+ # Simpan file ke folder yang ditentukan
149
+ with open(file_path, "wb") as file_object:
150
+ file_object.write(dataset.file.read())
151
+
152
+ # Panggil fungsi untuk mengekstrak file jika perlu
153
+ FileProcess.extract_zip(file_path)
154
+
155
+ return {'message': 'File successfully uploaded'}
156
+
157
+ @app.post('/set_params', tags=["Training"])
158
+ async def set_params(request: Request, params: TrainingParams):
159
+ global optimizer, epoch, batch_size
160
+
161
+ optimizer = params.optimizer
162
+ epoch = params.epoch
163
+ batch_size = params.batchSize
164
+
165
+ response = {'message': 'Set parameter sukses'}
166
+ return response
167
+
168
+ @app.get('/get_info_data', tags=["Training"])
169
+ def get_info_prepro():
170
+ global optimizer, epoch, batch_size
171
+ training_counts = GetLoadData.get_training_file_counts().json
172
+ testing_counts = GetLoadData.get_testing_file_counts().json
173
+ response = {
174
+ "optimizer": optimizer,
175
+ "epoch": epoch,
176
+ "batch_size": batch_size,
177
+ "training_counts": training_counts,
178
+ "testing_counts": testing_counts
179
+ }
180
+ return response
181
+
182
+ @app.get('/get_images_preprocess', tags=["Training"])
183
+ def get_random_images_crop():
184
+ images_face_landmark = GetLoadData.get_random_images(tahap="Face Landmark",public_url=public_url)
185
+ images_face_extraction = GetLoadData.get_random_images(tahap="landmark Extraction", public_url=public_url)
186
+
187
+ response = {
188
+ "face_landmark": images_face_landmark,
189
+ "landmark_extraction": images_face_extraction
190
+ }
191
+ return response
192
+
193
+ @app.get('/do_preprocessing', tags=["Training"])
194
+ async def do_preprocessing():
195
+ try:
196
+ data_train_pred.do_pre1(test="")
197
+ data_train_pred.do_pre2(test="")
198
+ return {'message': 'Preprocessing sukses'}
199
+ except Exception as e:
200
+ # Tangani kesalahan dan kembalikan respons kesalahan
201
+ error_message = f'Error during preprocessing: {str(e)}'
202
+ raise HTTPException(status_code=500, detail=error_message)
203
+
204
+ @app.get('/do_training', tags=["Training"])
205
+ def do_training():
206
+ global epoch
207
+ folder = ""
208
+ if (face_landmark_img == True and landmark_extraction_img == True):
209
+ folder = "Landmark Extraction"
210
+ elif (face_landmark_img == True and landmark_extraction_img == False):
211
+ folder = "Face Landmark"
212
+ # --------------------------------------------------------------
213
+ train_dataset_path = f"./static/dataset/{folder}/Training/"
214
+ test_dataset_path = f"./static/dataset/{folder}/Testing/"
215
+
216
+ train_image_df, test_image_df = GetLoadData.load_image_dataset(train_dataset_path, test_dataset_path)
217
+
218
+ train_gen, test_gen = data_train_pred.data_configuration(train_image_df, test_image_df)
219
+ model = data_train_pred.model_architecture()
220
+
221
+ result = data_train_pred.train_model(model, train_gen, test_gen, epoch)
222
+
223
+ # Mengambil nilai akurasi training dan validation dari objek result
224
+ train_acc = result.history['accuracy'][-1]
225
+ val_acc = result.history['val_accuracy'][-1]
226
+
227
+ # Plot accuracy
228
+ data_train_pred.plot_accuracy(result=result, epoch=epoch)
229
+ acc_url = f'{public_url}/static/accuracy_plot.png'
230
+
231
+ # Plot loss
232
+ data_train_pred.plot_loss(result=result, epoch=epoch)
233
+ loss_url = f'{public_url}/static/loss_plot.png'
234
+
235
+ # Confusion Matrix
236
+ data_train_pred.plot_confusion_matrix(model, test_gen)
237
+ conf_url = f'{public_url}/static/confusion_matrix.png'
238
+
239
+ return jsonify({'train_acc': train_acc, 'val_acc': val_acc, 'plot_acc': acc_url, 'plot_loss':loss_url,'conf':conf_url})