AILaborant committed on
Commit
9116ab5
·
verified ·
1 Parent(s): dfc3e99

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.layers import TextVectorization, Embedding, MultiHeadAttention, LayerNormalization, Dense, Dropout
3
+ from tensorflow.keras.models import Model
4
+ import gradio as gr
5
+ import json
6
+
7
# Marker token placed in front of a prompt before it is vectorized.
START_TOKEN = '<start>'
# Marker token placed after a prompt; also used to cut the generated output.
END_TOKEN = '<end>'
9
+
10
class TransformerBlock(tf.keras.layers.Layer):
    """Self-attention followed by a two-layer feed-forward network.

    The output of each sub-layer goes through dropout, is added back to that
    sub-layer's input, and the sum is layer-normalized.

    Args:
        embed_dim: Output width of the feed-forward network; also passed to
            the attention layer as ``key_dim``.
        num_heads: Number of attention heads.
        ff_dim: Units of the hidden ReLU layer of the feed-forward network.
        rate: Dropout rate used after attention and after the feed-forward
            network.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored so get_config() can report them.
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.ff_dim, self.rate = ff_dim, rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])

        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the attention and feed-forward sub-layers to ``inputs``.

        Args:
            inputs: Tensor used as both query and value of the attention layer.
            training: Forwarded to the dropout layers.

        Returns:
            The layer-normalized output of the feed-forward sub-layer.
        """
        # Attention sub-layer: attend, drop out, residual add, normalize.
        attended = self.dropout1(self.att(inputs, inputs), training=training)
        normed = self.layernorm1(inputs + attended)

        # Feed-forward sub-layer, following the same pattern.
        projected = self.dropout2(self.ffn(normed), training=training)
        return self.layernorm2(normed + projected)

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        return {
            **super().get_config(),
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        }
47
+
48
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned position embedding to a learned token embedding.

    Args:
        maxlen: Number of rows in the position-embedding table.
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Output dimension of both embeddings.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored so get_config() can report them.
        self.maxlen, self.vocab_size, self.embed_dim = maxlen, vocab_size, embed_dim

        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        """Embed the token ids in ``x`` and add the embedding of each position.

        Positions run from 0 up to the size of the last axis of ``x``.
        """
        seq_len = tf.shape(x)[-1]
        position_vectors = self.pos_emb(tf.range(0, seq_len, 1))
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        return {
            **super().get_config(),
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        }
73
+
74
def load_model(filename="tg-medium"):
    """Load the saved Keras model and build the matching text vectorizer.

    Args:
        filename: Base path shared by both artifacts: ``<filename>.h5`` is
            loaded as the model and ``<filename>.json`` as the vocabulary.

    Returns:
        A ``(model, vectorizer)`` tuple.
    """
    # The two custom layer classes are handed to Keras via custom_objects.
    custom_layers = {
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock,
    }
    model = tf.keras.models.load_model(f'{filename}.h5', custom_objects=custom_layers)

    with open(f'{filename}.json', 'r', encoding='utf-8') as vocab_file:
        vocabulary = json.load(vocab_file)

    # No standardization: text is vectorized exactly as it is passed in.
    vectorizer = TextVectorization(
        vocabulary=vocabulary,
        standardize=None,
        output_sequence_length=100,
        max_tokens=96000,
    )
    return model, vectorizer
88
+
89
def generate_text(model, vectorizer, prompt):
    """Generate a reply to ``prompt`` with a single forward pass of ``model``.

    The prompt is wrapped in ``START_TOKEN`` / ``END_TOKEN``, vectorized, and
    the last position of the vectorized sequence is dropped before it is fed
    to the model. The most likely token at every position is then mapped back
    to text through the vectorizer's vocabulary.

    Args:
        model: Object whose ``predict`` returns an array indexed as
            ``[batch, position, vocabulary]``.
        vectorizer: Callable that turns a list of strings into a 2-D array of
            token ids and exposes ``get_vocabulary()``.
        prompt: User text to respond to.

    Returns:
        The predicted tokens joined by single spaces, cut at the first
        ``END_TOKEN``; ``START_TOKEN`` and empty (padding) tokens are omitted.
    """
    framed_prompt = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    token_ids = vectorizer([framed_prompt])[:, :-1]

    # verbose=0 keeps Keras from printing a progress bar on every request.
    predictions = model.predict(token_ids, verbose=0)
    # predict() returns NumPy data, so argmax is taken directly on the array.
    predicted_ids = predictions[0].argmax(axis=-1)

    vocab = vectorizer.get_vocabulary()
    words = []
    for token_id in predicted_ids:
        word = vocab[token_id]
        if word == END_TOKEN:
            break
        # Skip padding (the empty string) and every start marker, so the
        # reply contains no runs of spaces or stray control tokens.
        if word and word != START_TOKEN:
            words.append(word)
    return ' '.join(words)
104
+
105
def main():
    """Load the model once and serve it through a Gradio text interface.

    The model and vectorizer are loaded a single time and captured by the
    request callback, so individual requests only run generation.
    """
    model, vectorizer = load_model()

    def generate_response(prompt):
        """Gradio callback: produce the model's reply to ``prompt``."""
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        # Was "Interference API": the demo exposes model inference.
        description="Inference API. (russian only)",
    )

    iface.launch()


if __name__ == "__main__":
    main()