Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from tensorflow.keras.models import load_model | |
from tensorflow.keras.preprocessing.text import tokenizer_from_json | |
from tensorflow.keras.preprocessing.sequence import pad_sequences | |
import numpy as np | |
import json | |
from typing import Union, List | |
app = FastAPI()

# Module-wide handles shared by every endpoint. Both start out as None and are
# (re)assigned by load_model_and_tokenizer().
model = tokenizer = None
def load_model_and_tokenizer(model_path='news_classifier.h5', tokenizer_path='tokenizer.json'):
    """Load the Keras classifier and its tokenizer into the module globals.

    Both artifacts are loaded into locals first and published to the
    ``model`` / ``tokenizer`` globals together, only after both loads
    succeed. A failed (re)load therefore never leaves a freshly loaded model
    paired with a stale or missing tokenizer.

    Args:
        model_path: Path of the saved Keras model passed to ``load_model``.
        tokenizer_path: Path of the tokenizer JSON file.

    Raises:
        Exception: Whatever ``load_model``, ``open`` or ``tokenizer_from_json``
            raise. The error is printed, then re-raised unchanged.
    """
    global model, tokenizer
    try:
        new_model = load_model(model_path)
        # tokenizer_from_json takes the raw JSON text, not a parsed dict.
        # JSON files are UTF-8, so don't rely on the platform default encoding.
        with open(tokenizer_path, 'r', encoding='utf-8') as f:
            new_tokenizer = tokenizer_from_json(f.read())
    except Exception as e:
        print(f"Error loading model or tokenizer: {str(e)}")
        raise
    # Publish both only once both loaded successfully.
    model, tokenizer = new_model, new_tokenizer
# Load on startup: this call runs once, when the module is imported.
# load_model_and_tokenizer() re-raises any loading error, so a missing or
# unreadable artifact makes the import itself fail.
# NOTE(review): that means the lazy-load fallback in predict() and the
# "not_loaded" status in read_root() appear reachable only if something else
# resets the globals — confirm whether startup was meant to be best-effort.
load_model_and_tokenizer()
class PredictionInput(BaseModel):
    # Request body for predict(): either one text or a list of texts to classify.
    # A plain string yields a single result; a list yields a list of results.
    text: Union[str, List[str]]
class PredictionOutput(BaseModel):
    # One classification result: the chosen label and the model's score for it.
    # NOTE(review): predict() builds plain dicts with these keys and no
    # response_model is visible here — confirm this class is still referenced.
    label: str
    score: float
def read_root():
    """Return static service metadata plus the current load status.

    ``status`` is ``"ready"`` when both the model and tokenizer globals hold
    a loaded object, and ``"not_loaded"`` otherwise.
    """
    # Compare with None explicitly so readiness never depends on how the
    # third-party model/tokenizer objects define truthiness (__bool__/__len__).
    loaded = model is not None and tokenizer is not None
    return {
        "message": "News Source Classifier API",
        "model_type": "LSTM",
        "version": "1.0",
        "status": "ready" if loaded else "not_loaded",
    }
# Length every tokenized input is padded/truncated to; must match the
# model's expected input length.
_MAX_SEQUENCE_LENGTH = 41
# Score for class index 1 ("foxnews") above which that label is chosen.
_FOXNEWS_THRESHOLD = 0.5


def _label_prediction(pred):
    """Map one row of model output to a ``{"label", "score"}`` dict.

    Column 1 is read as the "foxnews" score. The returned score belongs to
    whichever label was chosen ("nbc" gets ``1 - pred[1]``).
    """
    # NOTE(review): assumes each row holds at least two per-class scores
    # (e.g. a 2-way softmax) — confirm against the trained model.
    fox_score = pred[1]
    if fox_score > _FOXNEWS_THRESHOLD:
        return {"label": "foxnews", "score": float(fox_score)}
    return {"label": "nbc", "score": float(1 - fox_score)}


async def predict(input_data: PredictionInput):
    """Classify one text, or a list of texts, as "foxnews" or "nbc".

    Args:
        input_data: Request body whose ``text`` is a string or a list of strings.

    Returns:
        A single ``{"label": ..., "score": ...}`` dict when ``text`` is a
        string, otherwise a list of such dicts (``[]`` for an empty list).

    Raises:
        HTTPException: 500 "Model not loaded" if lazy loading fails, or 500
            carrying the error text if preprocessing or inference fails.
    """
    # Try to (re)load if either global is unset.
    if model is None or tokenizer is None:
        try:
            load_model_and_tokenizer()
        except Exception as e:
            raise HTTPException(status_code=500, detail="Model not loaded") from e

    is_single = isinstance(input_data.text, str)
    texts = [input_data.text] if is_single else input_data.text
    # Nothing to classify: answer directly rather than handing a zero-row
    # array to model.predict.
    if not texts:
        return []

    try:
        sequences = tokenizer.texts_to_sequences(texts)
        padded = pad_sequences(sequences, maxlen=_MAX_SEQUENCE_LENGTH)
        # NOTE(review): model.predict is synchronous and runs on the event loop
        # in this async endpoint, blocking other requests while it executes.
        predictions = model.predict(padded, verbose=0)
        results = [_label_prediction(pred) for pred in predictions]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return results[0] if is_single else results
async def reload_model():
    """Re-read the model and tokenizer from disk on demand.

    Returns:
        ``{"message": "Model reloaded successfully"}`` on success.

    Raises:
        HTTPException: Status 500 with the loader's error text on failure.
            The original exception is chained as ``__cause__`` so the real
            traceback is not lost.
    """
    try:
        load_model_and_tokenizer()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Model reloaded successfully"}