# FingerprintSOM / app.py
# Uploaded by Overglitch; last commit dfbd2d8 ("Update app.py", verified), 2.93 kB.
import math
import pickle
from typing import List

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from minisom import MiniSom
from pydantic import BaseModel
class InputData(BaseModel):
    """Request body for ``POST /predict/``.

    ``array`` is a nested list of ints. The prediction endpoint reshapes it to
    ``(256, 256, 1)``, so it must contain 256 * 256 = 65536 values in total.
    """

    # Presumably grayscale pixel rows -- TODO confirm the expected value range
    # (e.g. 0-255) with the clients that call this service.
    array: List[List[int]]


# ASGI application; the prediction route is registered on it with @app.post.
app = FastAPI()
# Load the SOM model.
def load_model(path='som.pkl'):
    """Load and return the pickled SOM stored at ``path``.

    Args:
        path: File to read. Defaults to ``'som.pkl'``, which is resolved
            against the current working directory (not this file's folder).

    Returns:
        The unpickled object -- presumably the trained ``MiniSom`` (this module
        imports ``MiniSom`` and calls ``.winner`` on the result); TODO confirm.

    Security:
        ``pickle.load`` can execute arbitrary code while unpickling. Only load
        model files you produced or otherwise trust; never point ``path`` at
        user-supplied data.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)
def sobel(I):
    """Return the horizontal and vertical Sobel responses of a 2-D image.

    The 3x3 kernels are applied as a correlation at every position where the
    window fits entirely inside ``I`` (no padding). For an ``m x n`` input,
    output pixel ``(r, c)`` is computed from ``I[r:r+3, c:c+3]``. Both outputs
    have shape ``(m - 2, n - 2)``, and every pixel is computed, including the
    last row and column.

    Args:
        I: 2-D array-like image.

    Returns:
        Tuple ``(Gx, Gy)`` of ``float32`` arrays of shape ``(m - 2, n - 2)``.

    Raises:
        ValueError: if ``I`` is not 2-D or is smaller than 2x2.
    """
    # Accumulate in float64: with small unsigned dtypes (e.g. uint8), the
    # negative kernel weights would otherwise wrap around or be rejected.
    I = np.asarray(I, dtype=np.float64)
    m, n = I.shape
    rows, cols = m - 2, n - 2
    if rows < 0 or cols < 0:
        raise ValueError(f"image must be at least 2x2, got shape {I.shape}")
    gx = ((-1, 0, 1), (-2, 0, 2), (-1, 0, 1))
    gy = ((1, 2, 1), (0, 0, 0), (-1, -2, -1))
    Gx = np.zeros((rows, cols), np.float64)
    Gy = np.zeros((rows, cols), np.float64)
    # Each kernel tap adds one shifted view of the image, weighted by that tap.
    for di in range(3):
        for dj in range(3):
            window = I[di:di + rows, dj:dj + cols]
            Gx += gx[di][dj] * window
            Gy += gy[di][dj] * window
    return Gx.astype(np.float32), Gy.astype(np.float32)
def medfilt2(G, d=3):
    """Apply a ``d x d`` median filter to a 2-D array with zero padding.

    Each output pixel is the median of the ``d x d`` neighbourhood centred on
    the corresponding input pixel. Positions outside ``G`` count as 0. Every
    pixel is filtered, including the last row and column.

    Args:
        G: 2-D array-like.
        d: Window size; must be a positive odd integer (default 3).

    Returns:
        ``float32`` array with the same shape as ``G``.

    Raises:
        ValueError: if ``d`` is not a positive odd integer or ``G`` is not 2-D.
    """
    if d < 1 or d % 2 == 0:
        raise ValueError(f"window size d must be a positive odd integer, got {d}")
    G = np.asarray(G)
    m, n = G.shape
    r = d // 2
    padded = np.zeros((m + 2 * r, n + 2 * r), np.float32)
    padded[r:r + m, r:r + n] = G
    # Stack the d*d shifted views so that axis -1 holds each pixel's window.
    windows = np.stack(
        [padded[di:di + m, dj:dj + n] for di in range(d) for dj in range(d)],
        axis=-1,
    )
    # d*d is odd, so the median is exactly the middle order statistic.
    return np.sort(windows, axis=-1)[..., (d * d) // 2]
def orientacion(patron, w, *, full_blocks=False):
    """Estimate a block-wise orientation map, in degrees, for a 2-D image.

    The Sobel gradients of ``patron`` are median-filtered (3x3). For each block
    ``(i, j)`` the angle is

        theta = 0.5 * atan2(sum(2 * Gx * Gy), sum(Gx**2 - Gy**2)) + pi / 2

    converted to degrees, so values lie in [0, 180]. For fingerprint images
    this is presumably the ridge orientation. Trailing gradient rows and
    columns that do not fill a whole block are ignored.

    Args:
        patron: 2-D image array.
        w: Block size in pixels (>= 1).
        full_blocks: If False (default), block ``(i, j)`` is computed from
            rows ``i*w:(i+1)*w`` of the single gradient column ``j``. If True,
            it is computed from the full ``w x w`` tile
            ``[i*w:(i+1)*w, j*w:(j+1)*w]``.

    Returns:
        ``float32`` array of shape ``(m // w, n // w)``, where ``(m, n)`` is
        the shape of the filtered gradient map.

    Raises:
        ValueError: if ``w < 1``.

    NOTE(review): the default single-column sampling looks like a typo for
    ``j*w:(j+1)*w``; with it, only the first ``n // w`` gradient columns ever
    contribute. It stays the default because the pickled SOM and its label
    map were presumably trained on features computed this way. Switch to
    ``full_blocks=True`` only together with retraining. TODO confirm.
    """
    if w < 1:
        raise ValueError(f"block size w must be >= 1, got {w}")
    Gx, Gy = sobel(patron)
    Gx = medfilt2(Gx)
    Gy = medfilt2(Gy)
    m, n = Gx.shape
    bm, bn = m // w, n // w
    if full_blocks:
        # Axes: (block_row, row_in_block, block_col, col_in_block).
        gx = Gx[:bm * w, :bn * w].reshape(bm, w, bn, w)
        gy = Gy[:bm * w, :bn * w].reshape(bm, w, bn, w)
        axes = (1, 3)
    else:
        # Axes: (block_row, row_in_block, column); column j serves block j.
        gx = Gx[:bm * w, :bn].reshape(bm, w, bn)
        gy = Gy[:bm * w, :bn].reshape(bm, w, bn)
        axes = 1
    gx = gx.astype(np.float64)
    gy = gy.astype(np.float64)
    # "+ 0.0" maps a -0.0 sum to +0.0. A block whose products are all -0.0
    # then gets the same angle as one whose products are +0.0; otherwise atan2
    # would flip between -pi and +pi at its branch cut.
    YY = np.sum(2.0 * gx * gy, axis=axes) + 0.0
    XX = np.sum(gx ** 2 - gy ** 2, axis=axes) + 0.0
    theta = 0.5 * np.arctan2(YY, XX) + np.pi / 2.0
    return (theta * (180.0 / np.pi)).astype(np.float32)
def representativo(imarray):
    """Return the flattened orientation-map feature vector for one image.

    ``imarray`` is squeezed, so a trailing channel axis of size 1 is accepted.
    A one-pixel frame is then cropped from every side, and the ``orientacion``
    map computed with 14-pixel blocks is flattened in row-major order. For a
    256 x 256 image this gives 18 * 18 = 324 values.
    """
    image = np.squeeze(imarray)
    height, width = image.shape
    inner = image[1:height - 1, 1:width - 1]
    orientation_map = orientacion(inner, 14)
    return np.ravel(np.asarray(orientation_map))
# Load the SOM once at import time; every request reuses this module-level
# instance. NOTE(review): this unpickles 'som.pkl' -- ship only a trusted file.
som = load_model()
# Label lookup table indexed by the SOM winner coordinates (row, col); the
# prediction endpoint returns MM[som.winner(...)]. The 10x10 shape presumably
# matches the SOM grid -- TODO confirm against som.pkl. Entries 0-3 appear to
# be class labels assigned to neurons, and -1 presumably marks neurons with no
# assigned label; confirm against the training notebook.
MM = np.array([
    [ 0., -1., -1., -1., -1., 2., -1., -1., -1., 3.],
    [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
    [-1., -1., -1., 1., -1., -1., -1., -1., -1., -1.],
    [ 1., -1., -1., -1., -1., -1., -1., -1., -1., 0.],
    [-1., -1., -1., -1., 1., -1., -1., -1., -1., -1.],
    [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
    [ 3., -1., -1., -1., -1., -1., -1., -1., -1., 3.],
    [-1., -1., -1., 0., -1., -1., 3., -1., -1., -1.],
    [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.],
    [ 2., -1., -1., -1., 1., -1., -1., -1., -1., 2.]
])
@app.post("/predict/")
async def predict(data: InputData):
    """Classify one fingerprint image with the SOM.

    ``data.array`` must hold 256 * 256 = 65536 integers; it is reshaped to
    ``(256, 256, 1)``. Features come from ``representativo``, the winning
    neuron from ``som.winner``, and the value stored for that neuron in
    ``MM`` is returned.

    Returns:
        ``{"prediction": float}``: the ``MM`` entry of the winning neuron.

    Raises:
        HTTPException: 422 if ``array`` cannot be reshaped to 256 x 256;
            500 for any other failure during inference.
    """
    try:
        image = np.asarray(data.array).reshape(256, 256, 1)
    except ValueError as exc:
        # Malformed input is a client error; report it as 422, not 500.
        raise HTTPException(
            status_code=422, detail=f"array must be 256x256: {exc}"
        ) from exc
    try:
        # Feature extraction is CPU-bound; run it in a worker thread so the
        # event loop keeps serving other requests meanwhile.
        features = await run_in_threadpool(representativo, image)
        # NOTE(review): MiniSom.winner appears to cache its activation map on
        # the SOM instance, so it stays on the event-loop thread instead of
        # running concurrently in worker threads.
        winner = som.winner(features.reshape(1, -1))
        prediction = float(MM[winner])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"prediction": prediction}