File size: 2,012 Bytes
01ad777
 
 
 
 
 
 
 
 
 
 
 
 
f541ec2
 
 
 
 
 
01ad777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import tensorflow as tf
import joblib
import numpy as np
from huggingface_hub import hf_hub_download

# Fetch the trained artifacts from the Hugging Face Hub.
# Both files live in the same repository, so the repo id is spelled out once.
_HF_REPO_ID = "rio3210/amharic-hate-speech-using-rnn-bidirectional"
model_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="amharic_hate_speech_rnn_model.keras",
)
tokenizer_path = hf_hub_download(
    repo_id=_HF_REPO_ID,
    filename="tokenizer.joblib",
)

# Deserialize both artifacts once, at import time; the request handler below
# reuses these module-level objects on every call.
keras_model = tf.keras.models.load_model(model_path)

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only safe as long as the repository above is trusted.
tokenizer = joblib.load(tokenizer_path)

# Define the FastAPI application
app = FastAPI()

# Setup CORS: the API is open to any origin, method and header.
# Credentials are deliberately disabled. A wildcard origin combined with
# allow_credentials=True would let any website issue credentialed
# cross-origin requests, and FastAPI's CORS documentation states that the
# wildcard lists must not be used together with allow_credentials=True.
# Nothing in this module reads cookies or auth headers, so credentials are
# not needed. To enable them later, replace the "*" entries with explicit
# origins, methods and headers first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request body schema for POST /textclassify.
# Documented with comments rather than a class docstring so the generated
# JSON schema for this model stays exactly as it is.
class ClassifyRequest(BaseModel):
    text: str  # the raw text to classify

# Preprocessing function
def preprocess_text(text: str, tokenizer, max_length: int = 100):
    """Turn one raw string into a fixed-length, single-item batch of token ids.

    Args:
        text: The raw input string.
        tokenizer: Any object exposing ``texts_to_sequences``; it is called
            with a one-element list containing ``text``.
        max_length: Length each sequence is padded or truncated to.

    Returns:
        The output of ``pad_sequences`` for that one-item batch. Padding and
        truncation are both applied at the end of the sequence ("post").
    """
    token_ids = tokenizer.texts_to_sequences([text])
    pad = tf.keras.preprocessing.sequence.pad_sequences
    return pad(token_ids, maxlen=max_length, padding="post", truncating="post")

# Classification route
#
# Classifies the submitted text as "Hate" or "Free" using the module-level
# tokenizer and Keras model, and returns a JSON body of the form
# {"label": <str>, "confidence": <float>}.
#
# "confidence" is the raw model output, i.e. the value compared against the
# 0.5 threshold. It is NOT flipped for the "Free" label, so a low value means
# a confident "Free". Kept as-is because clients may depend on it.
#
# Documented with comments rather than a docstring so the OpenAPI description
# of this route is unchanged.
@app.post("/textclassify")
def classify_text(request_body: ClassifyRequest):
    # Tokenize and pad the text into the single-row batch passed to the model.
    model_input = preprocess_text(request_body.text, tokenizer)

    # verbose=0 stops predict() from printing a progress bar on every request.
    prediction = keras_model.predict(model_input, verbose=0)

    # Extract a true Python scalar before using it. The previous
    # float(prediction[0]) converted a 1-D array to a scalar, which NumPy has
    # deprecated since 1.25 and plans to turn into an error.
    # Assumes the model emits one score per input row -- TODO confirm.
    score = float(np.ravel(prediction)[0])
    label = "Hate" if score > 0.5 else "Free"  # Threshold-based classification

    # NOTE(review): 201 Created is preserved for backward compatibility even
    # though nothing is created here; 200 would be the conventional status.
    response = {"label": label, "confidence": score}
    return JSONResponse(content=response, status_code=201)

# Root route
# Returns a fixed greeting payload; handy as a quick "is the service up" probe.
# Documented with comments rather than a docstring so the OpenAPI description
# of this route is unchanged.
@app.get("/")
def home():
    greeting = dict(hello="world")
    return greeting