Spaces:
Build error
Build error
File size: 4,395 Bytes
0508a3e 498eb73 0508a3e 498eb73 0508a3e c0b3079 0508a3e 6fedb61 0508a3e ffe82b4 0508a3e 6fedb61 0508a3e 6fedb61 0508a3e 442b516 0508a3e 4a892b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import torch
from torch import nn
import gradio as gr
import requests
from PIL import Image
import ast
import json
import numpy as np
import torch.nn.functional as F
class Net(nn.Module):
    """Custom network predicting the next character of a string.

    The model is an embedding layer, followed by an LSTM, followed by two
    linear layers that map the final hidden state to vocabulary logits.

    Parameters
    ----------
    vocab_size : int
        The number of characters in the vocabulary. The last index
        (`vocab_size - 1`) is used as the embedding's padding index.
    embedding_dim : int
        Dimension of the character embedding vectors.
    norm_type : int or float
        The `p` of the p-norm used when checking embedding vectors against
        `max_norm`.
    max_norm : int
        If any of the embedding vectors has a higher norm than `max_norm`
        it is rescaled.
    window_size : int
        Accepted so callers can pass the full training configuration; it is
        not used by any layer of the network.
    dense_dim : int
        Number of neurons in the linear layer that follows the LSTM.
    hidden_dim : int
        Size of the LSTM hidden state.
    n_layers : int
        Number of the layers of the LSTM.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim=2,
        norm_type=2,
        max_norm=2,
        window_size=3,
        dense_dim=32,
        hidden_dim=8,
        n_layers=1,
    ):
        super().__init__()

        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=vocab_size - 1,
            # Honour the caller's choice instead of a hard-coded 2; the
            # default value keeps the previous behaviour.
            norm_type=norm_type,
            max_norm=max_norm,
        )
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, batch_first=True, num_layers=n_layers
        )
        self.linear_1 = nn.Linear(hidden_dim, dense_dim)
        self.linear_2 = nn.Linear(dense_dim, vocab_size)

    def forward(self, x, h=None, c=None):
        """Run the forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(n_samples, window_size)` of dtype
            `torch.int64`.
        h, c : torch.Tensor or None
            Hidden states of the LSTM. They are only used when BOTH are
            given; otherwise the LSTM starts from its default zero state.

        Returns
        -------
        logits : torch.Tensor
            Tensor of shape `(n_samples, vocab_size)`.
        h, c : torch.Tensor
            Hidden states of the LSTM after consuming `x`, each of shape
            `(n_layers, n_samples, hidden_dim)`.
        """
        emb = self.embedding(x)  # (n_samples, window_size, embedding_dim)

        # A partial state (only h or only c) is ignored, as before.
        state = (h, c) if h is not None and c is not None else None
        _, (h, c) = self.lstm(emb, state)  # (n_layers, n_samples, hidden_dim)

        # Average the final hidden state over the LSTM layers.
        h_mean = h.mean(dim=0)  # (n_samples, hidden_dim)
        x = self.linear_1(h_mean)  # (n_samples, dense_dim)
        logits = self.linear_2(x)  # (n_samples, vocab_size)

        return logits, h, c
def _load_json(path):
    """Read and parse the JSON file at `path`, decoding it as UTF-8."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Artifacts expected in the working directory: hyper-parameters, the
# code -> index table, the list of codes to sample from, and the two
# translation tables between item names and codes.
setup_config = _load_json("setup_config.json")
ch2ix = _load_json("ch2ix.json")
vocab = _load_json("vocab.json")
code_to_text = _load_json("code_to_text.json")
text_to_code = _load_json("text_to_code.json")

# Hyper-parameters are coerced to int before being handed to the network.
window_size = int(setup_config["window_size"])
vocab_size = int(setup_config["vocab_size"])
embedding_dim = int(setup_config["embedding_dim"])
norm_type = int(setup_config["norm_type"])
max_norm = int(setup_config["max_norm"])

# Keyword arguments keep the mapping to Net's parameters explicit.
model = Net(
    vocab_size,
    embedding_dim=embedding_dim,
    norm_type=norm_type,
    max_norm=max_norm,
    window_size=window_size,
)

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# host; inference in this script runs on CPU tensors.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints.
model.load_state_dict(torch.load("model_lstm.pth", map_location="cpu"))
model.eval()
def process_text(text):
    """Recommend five follow-up items for a comma-separated list of items.

    Each recognised item name is translated to its code with
    `text_to_code`. The model is then run five times; each time the next
    code is sampled from the predicted distribution and appended to the
    history before the next step.

    Parameters
    ----------
    text : str
        Item names separated by commas, e.g. ``"Electrical, Heating"``.
        Whitespace around a name is ignored when the untrimmed name is not
        a known item. Unknown names are skipped.

    Returns
    -------
    list
        The names (via `code_to_text`) of the sampled items, in sampling
        order. Sampled codes without a name are dropped. The list is empty
        when `text` contains no recognised item.
    """
    window_size = int(setup_config["window_size"])
    # Codes missing from `ch2ix` are mapped to the last vocabulary index.
    unknown_ix = int(setup_config["vocab_size"]) - 1

    initial_text = []
    for piece in text.split(","):
        # Try the exact piece first so previously matching input still
        # matches, then the trimmed form so "A, B" works like "A,B".
        for candidate in (piece, piece.strip()):
            if candidate in text_to_code:
                initial_text.append(text_to_code[candidate])
                break

    if not initial_text:
        # Nothing recognised: a zero-length sequence cannot be fed to the
        # LSTM, so there is nothing to recommend.
        return []

    res = []
    with torch.no_grad():  # inference only, no autograd graph needed
        for _ in range(5):
            # Only the most recent `window_size` codes are fed to the model.
            initial_text = initial_text[-window_size:]
            features = torch.LongTensor(
                [[ch2ix.get(c, unknown_ix) for c in initial_text]]
            )
            logits, _, _ = model(features)
            probs = F.softmax(logits[0], dim=0).numpy()
            new_ch = np.random.choice(vocab, p=probs)
            initial_text.append(new_ch)
            res.append(new_ch)

    return [code_to_text[item] for item in res if item in code_to_text]
# Static texts shown on the demo page.
title = "Interactive demo: LSTM based item recomendation"
description = "LSTM based item recomendation."
article = "<p style='text-align: center'></p>"

# One free-text box in, plain text out; the placeholder shows the expected
# comma-separated input format.
_interface_options = {
    "fn": process_text,
    "inputs": gr.Textbox(
        placeholder="Audio-Visual, Electrical, Heating, Interior Shutters, Low Voltage"
    ),
    "outputs": "text",
    "title": title,
    "description": description,
    "article": article,
}
iface = gr.Interface(**_interface_options)

# Start serving the demo.
iface.launch(debug=False)