hatespeech / app.py
hm-auch
update init app, intro first common file for refactoring
86814fc
raw
history blame
1.44 kB
import transformers
import hscommon
import gradio as gr
import tensorflow as tf
from official.nlp import optimization # to create AdamW optimizer
# Directory of the saved fine-tuned Keras model, loaded below with
# tf.keras.models.load_model.
MODEL_DIRECTORY = 'save/modelV1'
# Hugging Face model id whose tokenizer is used to encode input sentences.
PRETRAINED_MODEL_NAME = 'dbmdz/bert-base-german-cased'
# Created at import time (from_pretrained may download from the hub on first run).
TOKENIZER = transformers.BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
# Passed to hscommon.encode; presumably the pad/truncate length in tokens —
# TODO confirm in hscommon and keep in sync with the length used in training.
MAX_SEQUENCE_LENGTH = 256
# Training hyperparameters consumed by compile_model() to build the
# optimizer and its learning-rate schedule.
EPOCHS = 2
OPTIMIZER = 'adamw'
INIT_LR = 3e-5
# from_logits=False: model outputs are treated as probabilities, i.e. the
# model is assumed to end in a sigmoid — NOTE(review): confirm against the saved model.
LOSS = tf.keras.losses.BinaryCrossentropy(from_logits=False)
# A single metric instance; compile_model() wraps it in a list.
METRICS = tf.metrics.BinaryAccuracy()
def compile_model(model, steps_per_epoch=10):
    """Compile ``model`` with an AdamW optimizer, binary cross-entropy and accuracy.

    In this file the model is only used through ``predict()``, so the
    optimizer settings do not influence predictions. Compiling restores a
    ready-to-use Keras model after it was loaded with ``compile=False``.

    Args:
        model: The Keras model to compile (compiled in place).
        steps_per_epoch: Steps per epoch used to size the learning-rate
            schedule. Defaults to the previously hard-coded value of 10.

    Returns:
        The same ``model`` object, now compiled.
    """
    num_train_steps = steps_per_epoch * EPOCHS
    # Warm the learning rate up over the first 10% of all training steps.
    num_warmup_steps = int(0.1 * num_train_steps)
    optimizer = optimization.create_optimizer(
        init_lr=INIT_LR,
        # Fix: this previously passed steps_per_epoch, so the schedule was
        # sized for a single epoch instead of all EPOCHS epochs.
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        optimizer_type=OPTIMIZER,
    )
    model.compile(optimizer=optimizer, loss=LOSS, metrics=[METRICS])
    return model
# Load the saved model without its stored training configuration
# (compile=False), then attach a fresh optimizer/loss/metric via
# compile_model, which hands back the very same model object.
hs_detection_model = compile_model(
    tf.keras.models.load_model(MODEL_DIRECTORY, compile=False)
)
def inference(sentence):
    """Run the hate-speech model on a single sentence.

    Args:
        sentence: Text entered in the Gradio textbox.

    Returns:
        The raw output of ``hs_detection_model.predict`` for a batch holding
        this one sentence. Gradio renders it as text.
    """
    # hscommon.encode is called with a batch, so wrap the sentence in a list.
    # Its result is presumably a dict of model inputs (e.g. token ids and
    # attention mask) — TODO confirm in hscommon.
    encoded_sentence = hscommon.encode([sentence], TOKENIZER, MAX_SEQUENCE_LENGTH)
    # Keras predict() supports a list of arrays, but a dict_values view is not
    # a recognized input structure. Materialize it; dict order is preserved,
    # so the input order is unchanged.
    prediction = hs_detection_model.predict(list(encoded_sentence.values()))
    return prediction
# Single free-text input. The German placeholder reads: "Enter the sentence
# here that may contain hate speech."
# NOTE(review): `gr.inputs.*` is deprecated in Gradio 3.x and removed in 4.x
# (use `gr.Textbox` there) — confirm the pinned Gradio version before changing.
input_sentence_text = gr.inputs.Textbox(
    placeholder="Hier den Satz eingeben, der Hassrede enthalten kann.",
)
# Wire the textbox to inference() and show its result as plain text.
iface = gr.Interface(
    fn=inference,
    inputs=input_sentence_text,
    outputs="text",
)
# Start the Gradio web app.
iface.launch()