# NOTE(review): the three lines here were Hugging Face Spaces page chrome
# ("Spaces:" / "Running" / "Running") captured by scraping, not Python code —
# converted to this comment so the module parses.
import tensorflow as tf | |
from tensorflow.keras.layers import TextVectorization, Embedding, MultiHeadAttention, LayerNormalization, Dense, Dropout | |
from tensorflow.keras.models import Model | |
import gradio as gr | |
import json | |
START_TOKEN = '<start>' | |
END_TOKEN = '<end>' | |
class TransformerBlock(tf.keras.layers.Layer):
    """Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each sub-layer wrapped with dropout
    and a residual connection + layer normalization."""

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored so get_config() can round-trip
        # the layer through model save/load.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential(
            [Dense(ff_dim, activation='relu'), Dense(embed_dim)]
        )
        self.layernorm1 = LayerNormalization(epsilon=1e-5)
        self.layernorm2 = LayerNormalization(epsilon=1e-5)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        # Self-attention with query = key = value = inputs (no mask here),
        # then residual add + norm.
        attended = self.att(inputs, inputs)
        attended = self.dropout1(attended, training=training)
        normed = self.layernorm1(inputs + attended)
        # Feed-forward sub-layer, then residual add + norm.
        transformed = self.ffn(normed)
        transformed = self.dropout2(transformed, training=training)
        return self.layernorm2(normed + transformed)

    def get_config(self):
        """Return the serializable config (base config + constructor args)."""
        base = super().get_config()
        base.update(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            ff_dim=self.ff_dim,
            rate=self.rate,
        )
        return base
class TokenAndPositionEmbedding(tf.keras.layers.Layer):
    """Embedding layer that sums a learned token embedding with a learned
    absolute position embedding of the same width."""

    def __init__(self, maxlen, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialize the layer.
        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        # Position ids 0..seq_len-1, derived from the actual input length
        # (must not exceed `maxlen`, the position table size).
        seq_len = tf.shape(x)[-1]
        pos_ids = tf.range(start=0, limit=seq_len, delta=1)
        return self.token_emb(x) + self.pos_emb(pos_ids)

    def get_config(self):
        """Return the serializable config (base config + constructor args)."""
        base = super().get_config()
        base.update(
            maxlen=self.maxlen,
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
        )
        return base
def load_model(filename="tg-medium"):
    """Load the saved Keras model and rebuild its TextVectorization layer.

    Parameters
    ----------
    filename : str
        Base name of the saved artifacts: ``<filename>.h5`` holds the model
        and ``<filename>.json`` holds the vocabulary list.

    Returns
    -------
    tuple
        ``(model, vectorizer)``.
    """
    # BUG FIX: `filename` was previously ignored and a hard-coded path was
    # used for both artifacts — build the paths from the parameter instead.
    model = tf.keras.models.load_model(f'{filename}.h5', custom_objects={
        'TokenAndPositionEmbedding': TokenAndPositionEmbedding,
        'TransformerBlock': TransformerBlock,
    })
    with open(f'{filename}.json', 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    # NOTE(review): these settings must match the vectorizer the model was
    # trained with (max_tokens / sequence length / no standardization) —
    # confirm against the training script.
    vectorizer = TextVectorization(
        max_tokens=128000,
        output_sequence_length=100,
        standardize=None,
        vocabulary=vocab,
    )
    return model, vectorizer
def generate_text(model, vectorizer, prompt):
    """Run one greedy forward pass over the prompt and decode the result.

    The prompt is wrapped in the sentinel tokens, vectorized, and fed to the
    model; the per-position argmax ids are mapped back through the vocabulary
    and truncated at the first end sentinel.
    """
    # Wrap the raw prompt in the model's sentinel tokens.
    wrapped = START_TOKEN + ' ' + prompt + ' ' + END_TOKEN
    # Vectorize, then drop the last position before feeding the model
    # (presumably mirrors the training-time input shift — TODO confirm).
    ids = vectorizer([wrapped])[:, :-1]
    logits = model.predict(ids)
    # Greedy decode: argmax over the vocabulary at every position.
    best_ids = tf.argmax(logits[0], axis=-1).numpy()
    vocab = vectorizer.get_vocabulary()
    tokens = [vocab[i] for i in best_ids]
    # Keep everything before the first end sentinel, if one was produced.
    if END_TOKEN in tokens:
        tokens = tokens[:tokens.index(END_TOKEN)]
    # Strip the (first) start sentinel if the model echoed it.
    if START_TOKEN in tokens:
        tokens.remove(START_TOKEN)
    return ' '.join(tokens)
def main():
    """Load the model once, then serve it behind a minimal Gradio text UI."""
    model, vectorizer = load_model()

    # Closure so Gradio's single-argument callback can reach the loaded
    # model and vectorizer.
    def generate_response(prompt):
        return generate_text(model, vectorizer, prompt)

    iface = gr.Interface(
        fn=generate_response,
        inputs=gr.Textbox(lines=2, placeholder="Start your conversation."),
        outputs="text",
        title="tg-medium",
        # BUG FIX: user-facing typo — "Interference" -> "Inference".
        description="Inference API. (russian only)",
    )
    iface.launch()
# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    main()