# app.py
import gradio as gr
import tensorflow as tf
import json
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json

# Load the Keras model architecture from model.json
def load_model_architecture(model_json_path='model.json'):
    with open(model_json_path, 'r', encoding='utf-8') as json_file:
        model_json = json_file.read()
    model = tf.keras.models.model_from_json(model_json)
    return model

# Load the model weights from polarisatie_model.h5
def load_model_weights(model, weights_path='polarisatie_model.h5'):
    model.load_weights(weights_path)
    return model

# Load the tokenizer from tokenizer.json
def load_tokenizer(tokenizer_path='tokenizer.json'):
    with open(tokenizer_path, 'r', encoding='utf-8') as f:
        tokenizer_json = f.read()
    tokenizer = tokenizer_from_json(tokenizer_json)
    return tokenizer

# Load max_length from max_length.txt
def load_max_length(max_length_path='max_length.txt'):
    with open(max_length_path, 'r') as f:
        max_length = int(f.read().strip())
    return max_length

# Preprocessing function
def preprocess(text, tokenizer, max_length):
    # Tokenize the text
    sequences = tokenizer.texts_to_sequences([text])
    # Pad the sequences
    padded = pad_sequences(sequences, maxlen=max_length, padding='post', truncating='post')
    return padded

# Prediction function
def predict_polarization(text, model, tokenizer, max_length):
    if not text.strip():
        return "Voer alstublieft een geldige zin in."  # "Please enter a valid sentence."
    # Preprocess the text
    processed_text = preprocess(text, tokenizer, max_length)
    # Make prediction
    prediction = model.predict(processed_text)
    # Assume the model outputs a single probability
    is_polarizing = bool(prediction[0] > 0.5)
    return "Ja" if is_polarizing else "Nee"  # "Yes" / "No"

# Load all components
model = load_model_architecture()
model = load_model_weights(model)
tokenizer = load_tokenizer()
max_length = load_max_length()

# Define the Gradio interface using Blocks
with gr.Blocks() as demo:
    gr.Markdown("# Polarisatie Thermometer")
    gr.Markdown("Voer een zin in om te beoordelen of deze polariserend is.")  # "Enter a sentence to judge whether it is polarizing."
    with gr.Row():
        input_text = gr.Textbox(label="Zin", placeholder="Schrijf hier een zin...", lines=2)  # label: "Sentence"
        evaluate_button = gr.Button("Evalueren")  # "Evaluate"
    result = gr.Textbox(label="Is polariserend?", interactive=False)  # "Is it polarizing?"
    evaluate_button.click(
        fn=lambda text: predict_polarization(text, model, tokenizer, max_length),
        inputs=input_text,
        outputs=result
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
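
# Note: the loaders above assume model.json, polarisatie_model.h5, tokenizer.json and
# max_length.txt sit next to app.py in the Space repository (these names come from the
# default arguments). A requirements.txt declaring tensorflow (and any other dependencies
# beyond the Gradio SDK) is also assumed so the Space can build; its exact contents are
# not part of this file.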