Spaces:
Sleeping
Sleeping
File size: 2,012 Bytes
01ad777 f541ec2 01ad777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import tensorflow as tf
import joblib
import numpy as np
from huggingface_hub import hf_hub_download
# Fetch the trained model and its tokenizer from the Hugging Face Hub.
# Both artifacts live in the same repository, so the id is named once.
_HF_REPO_ID = "rio3210/amharic-hate-speech-using-rnn-bidirectional"
model_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="amharic_hate_speech_rnn_model.keras",
)
tokenizer_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="tokenizer.joblib",
)

# Deserialize both artifacts once at import time; the request handlers reuse
# these module-level objects.
keras_model = tf.keras.models.load_model(model_path)
# NOTE(review): joblib.load unpickles the file and can therefore execute
# arbitrary code. This is only safe while the Hub repository above is trusted.
tokenizer = joblib.load(tokenizer_path)
# Create the application object that the route decorators below attach to.
app = FastAPI()

# Fully permissive CORS policy: every origin, method, and header is accepted.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True. Confirm whether credentialed
# requests are actually needed before tightening this.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Define the request body schema
class ClassifyRequest(BaseModel):
    # Request payload schema: a single required string field.
    # The classification route reads this field as the text to score.
    text: str
# Preprocessing function
def preprocess_text(text: str, tokenizer, max_length: int = 100):
    """Encode one raw string as a fixed-length batch of token ids.

    Args:
        text: The input text to encode.
        tokenizer: Object exposing ``texts_to_sequences``; it is called with
            a one-element list containing ``text``.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        The output of ``pad_sequences`` for the single input, with both
        padding and truncation applied at the end of the sequence ("post").
    """
    token_ids = tokenizer.texts_to_sequences([text])
    pad = tf.keras.preprocessing.sequence.pad_sequences
    return pad(token_ids, maxlen=max_length, padding="post", truncating="post")
# Classification route
@app.post("/textclassify")
def classify_text(request_body: ClassifyRequest):
    """Classify the submitted text as "Hate" or "Free".

    The text is tokenized and padded, scored by the loaded Keras model, and
    labelled "Hate" when the score is above 0.5, otherwise "Free".

    The response body contains ``label`` and ``confidence``. ``confidence``
    is the raw model score; it is not adjusted for the chosen label.
    """
    processed_text = preprocess_text(request_body.text, tokenizer)
    prediction = keras_model.predict(processed_text)

    # Extract the score as a plain Python float exactly once. prediction[0]
    # is an array rather than a scalar, and calling float() on an array with
    # ndim > 0 is deprecated since NumPy 1.25. .item() performs the
    # conversion explicitly and still raises if the row unexpectedly holds
    # more than one value.
    # Assumes the model emits a single score per input -- TODO confirm.
    score = float(np.asarray(prediction[0]).item())

    label = "Hate" if score > 0.5 else "Free"

    # NOTE(review): 201 (Created) is unusual for a read-only prediction, but
    # it is preserved so existing clients that check the status keep working.
    return JSONResponse(
        content={"label": label, "confidence": score},
        status_code=201,
    )
# Root route
@app.get("/")
def home():
    # Root endpoint: always answers with the same fixed greeting payload.
    greeting = {"hello": "world"}
    return greeting
|