import json

import gradio as gr
import tensorflow as tf
from tensorflow.keras.layers import (
    Dense,
    Dropout,
    Embedding,
    LayerNormalization,
    MultiHeadAttention,
    TextVectorization,
)

# Boundary tokens wrapped around every prompt. In the original text these were
# empty strings (likely angle-bracket tokens eaten by HTML stripping); the
# '<start>'/'<end>' spellings below are an assumption and must match the
# special tokens that were in the vocabulary when the model was trained.
START_TOKEN = '<start>'
END_TOKEN = '<end>'


class TransformerBlock(tf.keras.layers.Layer):
    """A standard Transformer encoder block: self-attention followed by a
    position-wise feed-forward network, each with residual + layer norm."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # Self-attention with residual connection and post-layer norm.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        # Feed-forward sub-layer, again with residual + norm.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        # Serialize the constructor arguments so the layer can be rebuilt
        # when the .h5 checkpoint is reloaded via custom_objects.
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config


class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sums a learned token embedding with a learned position embedding."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Position indices 0..seq_len-1, embedded and added to token embeddings.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        config = super().get_config()
        config.update({
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        })
        return config


def load_model(filename="tg-medium"):
    # The custom layers must be registered so Keras can deserialize them.
    model = tf.keras.models.load_model(f'{filename}.h5', custom_objects={
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock,
    })
    # Rebuild the vectorizer from the vocabulary saved alongside the model.
    with open(f'{filename}.json', 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    vectorizer = TextVectorization(
        max_tokens=96000,
        output_sequence_length=100,
        standardize=None,
        vocabulary=vocab,
    )
    return model, vectorizer


def generate_text(model, vectorizer, prompt):
    # Wrap the prompt in the same boundary tokens used during training.
    prompt = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    input_seq = vectorizer([prompt])
    input_seq = input_seq[:, :-1]
    # Single forward pass; take the argmax token at every output position.
    predictions = model.predict(input_seq)
    predicted_tokens = tf.argmax(predictions[0], axis=-1)
    vocab = vectorizer.get_vocabulary()
    output_tokens = [vocab[idx] for idx in predicted_tokens.numpy()]
    # Truncate at the first end token and drop the start token, if present.
    if END_TOKEN in output_tokens:
        end_index = output_tokens.index(END_TOKEN)
        output_tokens = output_tokens[:end_index]
    if START_TOKEN in output_tokens:
        output_tokens.remove(START_TOKEN)
    return ' '.join(output_tokens)


def main():
    model, vectorizer = load_model()

    def generate_response(prompt):
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        description="Inference API (Russian only).",
    )
    iface.launch()


if __name__ == "__main__":
    main()
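
# A minimal sketch of driving the pipeline directly, without the Gradio UI.
# It assumes tg-medium.h5 and tg-medium.json sit next to this script, and the
# Russian prompt is only an illustration; kept commented out so the module
# still runs unchanged as the Gradio app above.
#
#     model, vectorizer = load_model("tg-medium")
#     print(generate_text(model, vectorizer, "привет, как дела?"))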