File size: 3,570 Bytes
c20f071
 
 
cc4e3d8
da72438
840fdaa
44c8341
 
 
 
50edbe9
840fdaa
f73e099
 
 
44c8341
50edbe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85f8980
 
 
 
 
 
 
 
840fdaa
f73e099
 
 
 
 
 
 
 
840fdaa
 
0fbae15
fc829e4
 
cb13d0d
fc829e4
 
0fbae15
fc829e4
 
 
44c8341
 
8b25912
47aa6b1
fc829e4
 
 
 
 
 
 
 
 
 
 
47aa6b1
8b25912
 
 
294e24f
44c8341
85f8980
44c8341
2b584be
47aa6b1
294e24f
44c8341
 
0fbae15
7a972f8
cb13d0d
44c8341
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
# Install CPU-only PyTorch and the PyTorch Geometric extension packages at
# start-up. The exit codes returned by os.system are ignored, so a failed
# install is not reported here.
# NOTE(review): torch is unpinned while the PyG wheels below come from the
# torch-1.12.0+cpu index; if pip resolves a different torch, the compiled
# torch-scatter/torch-sparse wheels may not match it — confirm the intended
# torch version (or pin it).
os.system("pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu")
os.system("pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cpu.html")
# Graph-rendering dependencies, disabled together with the pygraphviz import
# and the graph-drawing code in fn().
# os.system("apt-get install -y graphviz-dev")
# os.system("pip install pygraphviz")

import gradio as gr
from glycowork.ml.processing import dataset_to_dataloader
import numpy as np
import torch
import torch.nn as nn
from networkx.drawing.nx_agraph import write_dot
# import pygraphviz as pgv
# from glycowork.motif.graph import glycan_to_nxGraph
# import networkx as nx

class EnsembleModel(nn.Module):
    """Average the predictions of several graph models.

    Each sub-model is called as ``model(x, edge_index, batch)``. Its output is
    moved to the CPU and converted to NumPy. The outputs are then averaged
    element-wise across models, and row 0 (the first graph of the batch) is
    returned.
    """

    def __init__(self, models):
        """Wrap *models*, an iterable of ``nn.Module`` instances."""
        super().__init__()
        # nn.ModuleList registers the sub-models as children, so .eval(),
        # .train() and .to() on the ensemble reach them; a plain list would
        # leave them untouched.
        self.models = nn.ModuleList(models)

    @staticmethod
    def _device_of(model):
        """Return the device holding *model*'s parameters (CPU if it has none)."""
        param = next(model.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def forward(self, data):
        """Return the averaged output for the first graph in *data*.

        Args:
            data: batch object exposing ``labels``, ``edge_index`` and
                ``batch`` tensors (as yielded by the glycowork dataloader).

        Returns:
            numpy.ndarray: element-wise mean of the sub-model outputs, row 0.

        Raises:
            ValueError: if the ensemble contains no models.
        """
        if len(self.models) == 0:
            raise ValueError("EnsembleModel has no sub-models to average")
        outputs = []
        with torch.no_grad():  # inference only; results leave as NumPy arrays
            for model in self.models:
                # Ensembles unpickled from older checkpoints may still hold a
                # plain list, so sync each sub-model's mode explicitly.
                model.train(self.training)
                # Feed each model on the device its weights live on, rather
                # than assuming CUDA whenever it is available (the checkpoints
                # are loaded with map_location="cpu").
                device = self._device_of(model)
                out = model(
                    data.labels.to(device),
                    data.edge_index.to(device),
                    data.batch.to(device),
                )
                outputs.append(out.cpu().numpy())
        return np.mean(outputs, axis=0)[0]
  
# Kingdom labels reported by the demo; index i of a model's output vector is
# shown as class_list[i].
class_list = [
    'Amoebozoa',
    'Animalia',
    'Bacteria',
    'Bamfordvirae',
    'Chromista',
    'Euryarchaeota',
    'Excavata',
    'Fungi',
    'Heunggongvirae',
    'Orthornavirae',
    'Pararnavirae',
    'Plantae',
    'Proteoarchaeota',
    'Protista',
    'Riboviria',
]

def _load_model(path):
    """Load a fully pickled PyTorch model from *path* onto the CPU.

    The checkpoints are whole pickled ``nn.Module`` objects, not state dicts.
    PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
    rejects them, so ``weights_only=False`` is passed when the installed torch
    accepts the argument. Releases older than 1.13 do not have it.

    SECURITY: unpickling can execute arbitrary code; only load trusted files
    shipped with this app.
    """
    import inspect

    kwargs = {"map_location": torch.device('cpu')}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


model1 = _load_model("model1.pt")
model2 = _load_model("model2.pt")
model3 = _load_model("model3.pt")

def fn(glycan, model):
    """Predict the kingdom of origin of a glycan for the Gradio interface.

    Args:
        glycan: glycan sequence as typed into the textbox; surrounding
            whitespace is ignored.
        model: label of the selected model radio button.
            "No data augmentation" uses ``model1`` and "Ensemble" uses
            ``model3``. Any other value, including ``None`` when nothing is
            selected, falls back to ``model2``.

    Returns:
        dict mapping each name in ``class_list`` to its softmax probability,
        the format consumed by the ``gr.Label`` output.
    """
    if model == "No data augmentation":
        model_pred = model1
    elif model == "Ensemble":
        model_pred = model3
    else:
        model_pred = model2
    model_pred.eval()

    # The dataloader takes parallel lists of sequences and labels; the label is
    # a placeholder because only inference is performed.
    data = next(iter(dataset_to_dataloader([glycan.strip()], [0], batch_size=1)))

    with torch.no_grad():  # inference only; no autograd graph needed
        if model == "Ensemble":
            # The ensemble handles device placement itself and already returns
            # the averaged NumPy output for the first graph.
            logits = model_pred(data)
        else:
            device = "cpu"  # single models are loaded with map_location="cpu"
            logits = model_pred(
                data.labels.to(device),
                data.edge_index.to(device),
                data.batch.to(device),
            ).cpu().numpy()[0]

    probs = _softmax(logits)
    return {name: float(p) for name, p in zip(class_list, probs)}


def _softmax(logits):
    """Return the softmax of a 1-D array of logits without overflowing."""
    logits = np.asarray(logits, dtype=np.float64)
    # Shifting by the maximum leaves the result unchanged mathematically but
    # keeps np.exp from overflowing to inf (and the ratio from becoming NaN).
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()


demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Textbox(label="Glycan sequence"),
        gr.Radio(label="Model", choices=["No data augmentation", "Random node deletion", "Ensemble"]),
    ],
    # Show every class; derived from class_list so the count cannot drift.
    outputs=[gr.Label(num_top_classes=len(class_list), label="Prediction")],
    # gradio documents allow_flagging as a string ("never"/"auto"/"manual");
    # the boolean form is deprecated.
    allow_flagging="never",
    title="SweetNet demo",
    examples=[
        ["GlcOSN(a1-4)GlcA(b1-4)GlcOSN(a1-4)GlcAOS(b1-4)GlcOSN(a1-4)GlcOSN", "No data augmentation"],
        ["Man(a1-2)Man(a1-3)[Man(a1-3)Man(a1-6)]Man(b1-4)GlcNAc(b1-4)GlcNAc", "Random node deletion"],
        ["Man(a1-2)Man(a1-3)[Man(a1-6)]Man(a1-6)[Man(a1-2)Man(a1-2)Man(a1-3)]Man(b1-4)GlcNAc", "Ensemble"],
    ],
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    demo.launch(debug=True)