# tg_medium / app.py
import tensorflow as tf
from tensorflow.keras.layers import TextVectorization, Embedding, MultiHeadAttention, LayerNormalization, Dense, Dropout
from tensorflow.keras.models import Model
import gradio as gr
import json
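
# Gradio demo for "tg-medium": a small Keras Transformer that generates a
# Russian-language reply in a single forward pass (see generate_text below).

# Special tokens that wrap every prompt; they must match the training data.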
START_TOKEN = '<start>'
END_TOKEN = '<end>'
class TransformerBlock(tf.keras.layers.Layer):
    """A single Transformer block: multi-head self-attention followed by a
    position-wise feed-forward network, each with dropout and a residual
    layer-normalization step."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # Self-attention over the input sequence (query = value = inputs).
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)  # residual connection
        # Position-wise feed-forward network.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)  # second residual connection
    def get_config(self):
        # Serialize the constructor arguments so the layer can be re-created
        # when the saved model is loaded.
        config = super().get_config()
        config.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate,
        })
        return config
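
# A minimal sketch of how TransformerBlock preserves tensor shape (the sizes
# below are hypothetical; the real ones are baked into the saved .h5 model):
#   block = TransformerBlock(embed_dim=128, num_heads=4, ff_dim=256)
#   x = tf.random.uniform((2, 100, 128))   # (batch, seq_len, embed_dim)
#   y = block(x, training=False)           # -> shape (2, 100, 128)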
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Sums a learned token embedding with a learned position embedding,
    giving the model information about word order."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Embed the position indices 0..seq_len-1, then broadcast-add them
        # to the token embeddings.
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions
    def get_config(self):
        config = super().get_config()
        config.update({
            'maxlen': self.maxlen,
            'vocab_size': self.vocab_size,
            'embed_dim': self.embed_dim,
        })
        return config
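
# A minimal sketch (hypothetical embed_dim; maxlen/vocab_size mirror the
# TextVectorization settings used below): token ids of shape (batch, seq_len)
# become embeddings of shape (batch, seq_len, embed_dim):
#   emb = TokenAndPositionEmbedding(maxlen=100, vocab_size=96000, embed_dim=128)
#   ids = tf.zeros((2, 100), dtype=tf.int32)
#   out = emb(ids)                         # -> shape (2, 100, 128)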
def load_model(filename="tg-medium"):
    """Load the saved Keras model and rebuild the TextVectorization layer
    from the vocabulary stored alongside it."""
    model = tf.keras.models.load_model(f'{filename}.h5', custom_objects={
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock
    })
    with open(f'{filename}.json', 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    # The vectorizer settings must match the training configuration exactly.
    vectorizer = TextVectorization(
        max_tokens=96000,
        output_sequence_length=100,
        standardize=None,
        vocabulary=vocab
    )
    return model, vectorizer
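
# load_model expects two files next to this script: "tg-medium.h5" (the model)
# and "tg-medium.json" (the token list saved from the training vectorizer), e.g.:
#   model, vectorizer = load_model()
#   print(len(vectorizer.get_vocabulary()))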
def generate_text(model, vectorizer, prompt):
    """Generate a reply in one forward pass: run the model once over the
    wrapped prompt and take the argmax token at every output position."""
    prompt = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    input_seq = vectorizer([prompt])
    input_seq = input_seq[:, :-1]  # drop the last position, as during training
    predictions = model.predict(input_seq)
    predicted_tokens = tf.argmax(predictions[0], axis=-1)
    # Map predicted token ids back to strings via the vectorizer vocabulary.
    vocab = vectorizer.get_vocabulary()
    output_tokens = [vocab[idx] for idx in predicted_tokens.numpy()]
    # Keep only the tokens before the first predicted <end>, and drop <start>.
    if END_TOKEN in output_tokens:
        end_index = output_tokens.index(END_TOKEN)
        output_tokens = output_tokens[:end_index]
    if START_TOKEN in output_tokens:
        output_tokens.remove(START_TOKEN)
    return ' '.join(output_tokens)
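
# Note: this is one model.predict call with a per-position argmax, not
# token-by-token autoregressive decoding, so the reply length is capped by
# the vectorizer's output_sequence_length.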
def main():
    model, vectorizer = load_model()

    def generate_response(prompt):
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        description="Inference API (Russian only)."
    )
    iface.launch()
if __name__ == "__main__":
    main()
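
# Running `python app.py` starts a local Gradio server; passing share=True to
# iface.launch() would also create a temporary public link.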