File size: 1,476 Bytes
74b29ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import pandas as pd
import torch.nn as nn

from flask import Flask, render_template, request
from model.bilstm import BiLSTM

# Flask application object; the route handlers in this module register on it.
app = Flask(__name__)

# Location of the trained BiLSTM checkpoint (relative to the working directory).
_CHECKPOINT_PATH = 'checkpoints/bilstm_result/epoch=23-step=3456.ckpt'

# The model is restored once at import time so every request reuses the same
# instance. The keyword arguments are forwarded to ``load_from_checkpoint``
# (presumably the constructor hyperparameters of BiLSTM; verify against
# model/bilstm.py).
model = BiLSTM.load_from_checkpoint(
    _CHECKPOINT_PATH,
    lr=1e-3,
    num_classes=1,
    input_size=12,
)

# Inference only: switch to evaluation mode and freeze the model.
model.eval()
model.freeze()

# Form fields read from the POST body, in the order they are fed to the model.
_FEATURE_NAMES = (
    'tau1', 'tau2', 'tau3', 'tau4',
    'p1', 'p2', 'p3', 'p4',
    'g1', 'g2', 'g3', 'g4',
)


def _parse_features(form):
    """Convert the submitted form fields into a list of finite floats.

    Args:
        form: Mapping-like object exposing ``get`` (e.g. ``request.form``).

    Returns:
        One float per entry of ``_FEATURE_NAMES``, in that order.

    Raises:
        ValueError: If a field is missing, is not a number, or is NaN/infinite.
    """
    values = []
    for name in _FEATURE_NAMES:
        raw = form.get(name)
        if raw is None:
            raise ValueError(f'Missing value for {name}')
        try:
            value = float(raw)
        except ValueError:
            raise ValueError(f'Invalid number for {name}') from None
        # NaN fails every comparison and +/-inf fails the strict bounds, so
        # this accepts finite numbers only.
        if not float('-inf') < value < float('inf'):
            raise ValueError(f'Value for {name} must be a finite number')
        values.append(value)
    return values


@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the input form and, on POST, the model's stability verdict.

    GET renders ``index.html`` without a result. POST reads the twelve
    fields listed in ``_FEATURE_NAMES`` from the submitted form, runs them
    through the module-level ``model`` as a single-row batch and renders
    ``index.html`` with ``result`` set to ``'Stable'`` when the sigmoid of
    the model output exceeds 0.5, otherwise ``'Unstable'``.

    Invalid input (missing, non-numeric or non-finite field) no longer
    reaches the model: previously a missing field became NaN and was
    silently reported as ``'Unstable'``, and a non-numeric field raised an
    unhandled ``ValueError``. Such requests now get HTTP 400 with the error
    text passed as ``result`` (assumes the template displays ``result``;
    confirm against templates/index.html).

    Returns:
        The rendered template, or a ``(rendered template, 400)`` tuple when
        the submitted form is invalid.
    """
    if request.method != 'POST':
        return render_template('index.html')

    try:
        features = _parse_features(request.form)
    except ValueError as err:
        return render_template('index.html', result=str(err)), 400

    # One row of twelve features: shape (1, 12), default torch float dtype.
    X = torch.tensor([features])

    with torch.no_grad():
        logits = model(X)
        # squeeze(1) drops the class dimension, leaving one score per row.
        score = torch.sigmoid(logits.squeeze(1))

    # The batch holds exactly one row, so the score is a single scalar.
    result = 'Stable' if score.item() > 0.5 else 'Unstable'

    return render_template('index.html', result=result)
    
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server when this file is
    # executed directly rather than imported.
    # NOTE(review): debug mode is switched on here; confirm this entry point
    # is never used for a publicly reachable deployment.
    use_debug = True
    app.run(debug=use_debug)