"""FastAPI service exposing an LSTM news-source classifier (labels: "foxnews" / "nbc")."""

import json
import logging
import threading
from typing import List, Union

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import tokenizer_from_json

logger = logging.getLogger(__name__)

MODEL_PATH = "news_classifier.h5"
TOKENIZER_PATH = "tokenizer.json"
# Padded sequence length fed to the model; must match the model's input length.
MAX_SEQUENCE_LENGTH = 41
# pred[1] above this value yields the "foxnews" label, otherwise "nbc".
FOXNEWS_THRESHOLD = 0.5

app = FastAPI()

# Global model and tokenizer; both stay None until a load succeeds.
model = None
tokenizer = None

# Serializes model inference and (re)loading. The endpoints below are plain
# ``def`` so FastAPI runs them in its threadpool instead of blocking the event
# loop; the lock keeps the one-at-a-time model access the async handlers had.
_model_lock = threading.Lock()


def _read_tokenizer_json(path):
    """Return the tokenizer config as the JSON string ``tokenizer_from_json`` expects.

    Accepts both a plain ``tokenizer.to_json()`` dump (decodes to a dict, so the
    raw text is returned) and a doubly-encoded one such as
    ``json.dumps(tokenizer.to_json())`` (decodes to the config string itself).
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()
    decoded = json.loads(raw)
    return decoded if isinstance(decoded, str) else raw


def load_model_and_tokenizer():
    """Load the model and tokenizer from disk and publish them as module globals.

    Both artifacts are loaded into locals first and assigned to the globals only
    after both succeed, so a failed reload keeps the previous pair intact
    instead of mixing a new model with a stale tokenizer.

    Raises:
        Exception: whatever ``load_model``, file I/O, or JSON parsing raises
            (logged before being re-raised).
    """
    global model, tokenizer
    try:
        new_model = load_model(MODEL_PATH)
        new_tokenizer = tokenizer_from_json(_read_tokenizer_json(TOKENIZER_PATH))
    except Exception:
        logger.exception("Error loading model or tokenizer")
        raise
    model, tokenizer = new_model, new_tokenizer


# Best-effort load at import time: if it fails the service still starts,
# reports "not_loaded" on "/", and /predict retries the load on demand.
try:
    load_model_and_tokenizer()
except Exception:
    logger.warning("Model not loaded at startup; /predict will retry loading")


class PredictionInput(BaseModel):
    # A single text, or a list of texts to classify as a batch.
    text: Union[str, List[str]]


class PredictionOutput(BaseModel):
    label: str  # "foxnews" or "nbc"
    score: float  # model probability assigned to the returned label


@app.get("/")
def read_root():
    """Describe the service and report whether model and tokenizer are loaded."""
    return {
        "message": "News Source Classifier API",
        "model_type": "LSTM",
        "version": "1.0",
        "status": "ready" if model is not None and tokenizer is not None else "not_loaded",
    }


def _classify(texts):
    """Run the loaded model on ``texts``; return one ``{"label", "score"}`` dict per text."""
    sequences = tokenizer.texts_to_sequences(texts)
    padded = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
    predictions = model.predict(padded, verbose=0)

    results = []
    for pred in predictions:
        # NOTE(review): assumes a 2-unit output where pred[1] is the "foxnews"
        # probability — confirm against the training code.
        is_fox = pred[1] > FOXNEWS_THRESHOLD
        results.append({
            "label": "foxnews" if is_fox else "nbc",
            "score": float(pred[1] if is_fox else 1 - pred[1]),
        })
    return results


@app.post("/predict", response_model=Union[PredictionOutput, List[PredictionOutput]])
def predict(input_data: PredictionInput):
    """Classify one text or a batch of texts.

    Returns a single result object for a string input and a list (possibly
    empty) for a list input.

    Raises:
        HTTPException: 500 if the model cannot be loaded or inference fails.
    """
    with _model_lock:
        if model is None or tokenizer is None:
            try:
                load_model_and_tokenizer()
            except Exception as e:
                raise HTTPException(status_code=500, detail="Model not loaded") from e

        is_single = isinstance(input_data.text, str)
        texts = [input_data.text] if is_single else input_data.text
        if not texts:
            # Empty batch: nothing to classify; avoid running predict on zero rows.
            return []

        try:
            results = _classify(texts)
        except Exception as e:
            logger.exception("Prediction failed")
            raise HTTPException(status_code=500, detail=str(e)) from e

    return results[0] if is_single else results


@app.post("/reload")
def reload_model():
    """Reload the model and tokenizer from disk.

    Raises:
        HTTPException: 500 with the underlying error message if loading fails;
            the previously loaded model/tokenizer pair is kept in that case.
    """
    with _model_lock:
        try:
            load_model_and_tokenizer()
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Model reloaded successfully"}