"""Gradio demo for tg-medium: a small Keras Transformer that turns a prompt
into a reply in a single forward pass (greedy, per-position argmax decoding)."""
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization, Embedding, MultiHeadAttention, LayerNormalization, Dense, Dropout
import gradio as gr
import json

# Sequence delimiters; presumably the same convention the model was trained with.
START_TOKEN = '<start>'
END_TOKEN = '<end>'

class TransformerBlock(tf.keras.layers.Layer):
    """Standard post-norm Transformer block: multi-head self-attention followed
    by a position-wise feed-forward network, each with dropout and a residual
    connection. Note that no causal mask is applied, so every position can
    attend to every other position."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])

        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # Self-attention (query = key = value = inputs), then residual + norm.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        # Position-wise feed-forward network, then residual + norm.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # Expose the constructor arguments so the layer can be re-created
        # when the saved model is loaded with custom_objects.
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config

class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sums a learned token embedding with a learned absolute position
    embedding, so the model can distinguish token order."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Position indices 0..seq_len-1, embedded and broadcast over the batch.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        config = super().get_config()
        config.update({
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        })
        return config

def load_model(filename="tg-medium"):
    # Register the custom layers so Keras can rebuild them from the HDF5 file.
    model = tf.keras.models.load_model(f'{filename}.h5', custom_objects={
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock
    })
    with open(f'{filename}.json', 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    # Rebuild the vectorizer; these settings presumably match training time
    # (no standardization, sequences padded or truncated to 100 tokens).
    vectorizer = TextVectorization(
        max_tokens=128000,
        output_sequence_length=100,
        standardize=None,
        vocabulary=vocab
    )
    return model, vectorizer
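
# For reference, the vocabulary file is assumed to be a plain JSON list of
# token strings, e.g. the result of json.dump(vectorizer.get_vocabulary(), f)
# at training time. This is an inference from the loading code above, not
# something the repository documents.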

def generate_text(model, vectorizer, prompt):
    # Wrap the prompt in the delimiter tokens and drop the last position,
    # which appears to mirror the shifted-input convention used in training.
    prompt = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    input_seq = vectorizer([prompt])
    input_seq = input_seq[:, :-1]
    # One forward pass with greedy per-position argmax; the reply is therefore
    # capped at output_sequence_length - 1 = 99 tokens.
    predictions = model.predict(input_seq)
    predicted_tokens = tf.argmax(predictions[0], axis=-1)
    # Map ids back to strings, cut at the first <end>, and strip <start>.
    vocab = vectorizer.get_vocabulary()
    output_tokens = [vocab[idx] for idx in predicted_tokens.numpy()]
    if END_TOKEN in output_tokens:
        output_tokens = output_tokens[:output_tokens.index(END_TOKEN)]
    if START_TOKEN in output_tokens:
        output_tokens.remove(START_TOKEN)
    return ' '.join(output_tokens)

def main():
    model, vectorizer = load_model()

    def generate_response(prompt):
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        description="Inference API (Russian only)."
    )

    iface.launch()

if __name__ == "__main__":
    main()
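
# A minimal command-line sketch, bypassing the Gradio UI (assumes the
# tg-medium.h5 weights and tg-medium.json vocabulary sit next to this script;
# the prompt is an arbitrary illustration, since the model is Russian-only):
#
#   model, vectorizer = load_model("tg-medium")
#   print(generate_text(model, vectorizer, "привет"))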