# Hugging Face Space status at time of capture: Runtime error
import transformers | |
import hscommon | |
import gradio as gr | |
import tensorflow as tf | |
from official.nlp import optimization # to create AdamW optimizer | |
# --- Model and tokenizer locations -------------------------------------------
MODEL_DIRECTORY = "save/modelV1"  # directory loaded via tf.keras.models.load_model
PRETRAINED_MODEL_NAME = "dbmdz/bert-base-german-cased"  # tokenizer checkpoint id
TOKENIZER = transformers.BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
MAX_SEQUENCE_LENGTH = 256  # max length handed to hscommon.encode

# --- Optimizer settings consumed by compile_model ----------------------------
EPOCHS = 2
OPTIMIZER = "adamw"
INIT_LR = 3e-5

# --- Loss and metric attached when compiling ---------------------------------
# from_logits=False: the loss treats the model output as probabilities, not logits.
LOSS = tf.keras.losses.BinaryCrossentropy(from_logits=False)
METRICS = tf.metrics.BinaryAccuracy()
def compile_model(model): | |
steps_per_epoch = 10 | |
num_train_steps = steps_per_epoch * EPOCHS | |
num_warmup_steps = int(0.1*num_train_steps) | |
optimizer = optimization.create_optimizer( | |
init_lr=INIT_LR, | |
num_train_steps=steps_per_epoch, | |
num_warmup_steps=num_warmup_steps, | |
optimizer_type=OPTIMIZER | |
) | |
model.compile(optimizer=optimizer, loss=LOSS, metrics=[METRICS]) | |
return model | |
# Load the saved detector without compiling it (compile=False), then attach the
# optimizer/loss/metric via compile_model. compile_model compiles in place and
# returns the same model object.
hs_detection_model = compile_model(
    tf.keras.models.load_model(MODEL_DIRECTORY, compile=False)
)
def inference(sentence): | |
encoded_sentence = hscommon.encode([sentence], TOKENIZER, MAX_SEQUENCE_LENGTH) | |
predicition = hs_detection_model.predict(encoded_sentence.values()) | |
return predicition | |
input_sentence_text = gr.inputs.Textbox(placeholder="Hier den Satz eingeben, der Hassrede enthalten kann.") | |
iface = gr.Interface(fn=inference, inputs=input_sentence_text, outputs="text") | |
iface.launch() |