File size: 1,438 Bytes
70b5fc5
86814fc
70b5fc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86814fc
70b5fc5
 
 
86814fc
70b5fc5
 
 
86814fc
 
70b5fc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import transformers
import hscommon

import gradio as gr
import tensorflow as tf

from official.nlp import optimization  # to create AdamW optimizer

# Path to the saved fine-tuned Keras model (TensorFlow SavedModel directory).
MODEL_DIRECTORY = 'save/modelV1'
# Hugging Face checkpoint the classifier was fine-tuned from (German, cased).
PRETRAINED_MODEL_NAME = 'dbmdz/bert-base-german-cased'
# Tokenizer matching the pretrained checkpoint; downloaded/loaded once at import time.
TOKENIZER = transformers.BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
MAX_SEQUENCE_LENGTH = 256  # max tokens per encoded sentence
EPOCHS = 2  # epochs used to size the optimizer's training schedule
OPTIMIZER = 'adamw'  # optimizer_type passed to official.nlp optimization.create_optimizer
INIT_LR = 3e-5  # initial learning rate for the AdamW schedule
# Binary classification loss; from_logits=False means the model emits probabilities.
LOSS = tf.keras.losses.BinaryCrossentropy(from_logits=False)
METRICS = tf.metrics.BinaryAccuracy()  # single metric, wrapped in a list at compile time

def compile_model(model):
    """Compile *model* in place with an AdamW optimizer and return it.

    Uses the module-level INIT_LR, OPTIMIZER, LOSS and METRICS settings.

    Args:
        model: An uncompiled ``tf.keras`` model.

    Returns:
        The same model instance, compiled.
    """
    # NOTE(review): steps_per_epoch is hard-coded — presumably a placeholder,
    # since the schedule only matters for training, not inference; confirm.
    steps_per_epoch = 10
    num_train_steps = steps_per_epoch * EPOCHS
    # Warm up over the first 10% of the total training schedule.
    num_warmup_steps = int(0.1 * num_train_steps)

    # BUG FIX: num_train_steps was previously passed as steps_per_epoch,
    # silently ignoring EPOCHS and leaving the computed num_train_steps dead.
    optimizer = optimization.create_optimizer(
        init_lr=INIT_LR,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        optimizer_type=OPTIMIZER
    )

    model.compile(optimizer=optimizer, loss=LOSS, metrics=[METRICS])
    return model

# Load the fine-tuned hate-speech model without its saved training config
# (compile=False), then re-compile it with the module-level optimizer setup.
hs_detection_model = tf.keras.models.load_model(MODEL_DIRECTORY, compile=False)
compile_model(hs_detection_model)

def inference(sentence):
    """Encode *sentence* and return the model's raw prediction output.

    Args:
        sentence: A single input string that may contain hate speech.

    Returns:
        The unmodified result of ``hs_detection_model.predict``.
    """
    # hscommon.encode presumably returns a dict of model inputs
    # (ids/masks/etc.) — its values are fed to predict; confirm in hscommon.
    encoded = hscommon.encode([sentence], TOKENIZER, MAX_SEQUENCE_LENGTH)
    return hs_detection_model.predict(encoded.values())

# Single free-text input box (German placeholder: "Enter the sentence that
# may contain hate speech here"). NOTE(review): gr.inputs.* is the legacy
# pre-3.x Gradio API — confirm the pinned gradio version still provides it.
input_sentence_text = gr.inputs.Textbox(placeholder="Hier den Satz eingeben, der Hassrede enthalten kann.")
# Wire inference() into a text-in/text-out web UI and start the local server.
iface = gr.Interface(fn=inference, inputs=input_sentence_text, outputs="text")
iface.launch()