File size: 1,086 Bytes
4a7b642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5509100
6189582
4a7b642
 
 
6189582
4a7b642
 
1751d3f
 
4a7b642
 
1751d3f
 
bd1dfb3
1751d3f
4a7b642
1751d3f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# import gradio as gr
# import numpy as np
# from modelutil import create_model


# def predict_digit(image):
#     try:
#         if image == None: pass
#     except:
#         model = create_model()
#         predictions = model.predict(image.reshape(1, 28, 28))
#         return np.argmax(predictions)

# gr.Interface(
#     title="MNIST Digit Classifier by Papa Sega",
#     fn=predict_digit, 
#     inputs=gr.Sketchpad( label="Draw a digit"),
#     outputs="number",
#     live=True
# ).launch()

def load_model_weights(model, checkpoint_path):
    """Restore saved weights into *model* from *checkpoint_path*.

    Thin wrapper that delegates to ``model.load_weights`` and discards its
    result, so this function always returns ``None``.

    NOTE(review): the later ``from modelutil import create_model,
    load_model_weights`` statement in this file rebinds this name, so this
    local definition is never the one actually called. Keep only one of the
    two.
    """
    restore = model.load_weights
    restore(checkpoint_path)

from pathlib import Path

import gradio as gr
import numpy as np

from modelutil import create_model, load_model_weights

# Resolve the checkpoint next to this script instead of relative to the
# current working directory, so the same weights are found no matter which
# directory the app is launched from. Kept as a str for load_weights.
checkpoint_path = str(Path(__file__).resolve().parent / 'checkpoint')

# Build the network once at import time and load the checkpoint into it;
# predict_digit reads this module-level `model` on every call.
model = create_model()
load_model_weights(model, checkpoint_path)

def predict_digit(image):
    """Classify a hand-drawn digit from the sketchpad.

    Args:
        image: The sketchpad value passed by Gradio, or ``None`` when there
            is nothing to classify. When present it must be reshapeable to
            ``(1, 28, 28)``. It is presumably a 28x28 grayscale array;
            confirm against the installed Gradio version's ``Sketchpad``
            output format.

    Returns:
        The index of the highest-scoring prediction as a builtin ``int``,
        or ``None`` when *image* is ``None`` (which clears the output).
    """
    # live=True re-runs this on every input change, so Gradio may hand us
    # None (e.g. an empty/cleared canvas); None has no .reshape.
    if image is None:
        return None
    predictions = model.predict(image.reshape(1, 28, 28))
    # np.argmax yields a NumPy integer scalar; return a plain int instead.
    return int(np.argmax(predictions))

# Web UI: with live=True Gradio re-runs predict_digit whenever the drawing
# changes and shows the returned class in a numeric output field.
demo = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(label="Draw a digit"),
    outputs="number",
    live=True,
    title="MNIST Digit Classifier by Papa Sega",
)
demo.launch()