File size: 2,870 Bytes
60dc372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f3a7a6
60dc372
8f3a7a6
 
60dc372
8f3a7a6
 
 
60dc372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import json
from typing import Union, List

app = FastAPI()

# Shared inference artifacts, populated by load_model_and_tokenizer();
# a value of None means that artifact has not been loaded yet.
model, tokenizer = None, None

def load_model_and_tokenizer(model_path='news_classifier.h5', tokenizer_path='tokenizer.json'):
    """Load the Keras classifier and its tokenizer into the module globals.

    Both artifacts are loaded into locals first. They are published to the
    ``model`` / ``tokenizer`` globals only once *both* have loaded, so a
    failure never leaves a freshly loaded model paired with a stale or
    missing tokenizer.

    Args:
        model_path: Path of the saved Keras model passed to ``load_model``.
        tokenizer_path: Path of a file whose contents are the JSON string
            accepted by ``tokenizer_from_json``.

    Raises:
        Exception: Whatever the model loader, ``open`` or the tokenizer
            parser raised. The error is printed and then re-raised unchanged.
    """
    global model, tokenizer
    try:
        new_model = load_model(model_path)

        # tokenizer_from_json expects the raw JSON string, not a parsed dict.
        # JSON is UTF-8 by definition, so don't depend on the platform default.
        with open(tokenizer_path, 'r', encoding='utf-8') as f:
            new_tokenizer = tokenizer_from_json(f.read())
    except Exception as e:
        print(f"Error loading model or tokenizer: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise

    # Publish both together only after everything loaded successfully.
    model, tokenizer = new_model, new_tokenizer

# Load on startup: runs once at import time. load_model_and_tokenizer()
# re-raises any loading error, so a missing or corrupt artifact aborts the
# import of this module.
# NOTE(review): because of that, the lazy (re)load in predict() and the
# "not_loaded" status in read_root() can only matter after a later reset of
# the globals — confirm whether failing fast at startup is intended.
load_model_and_tokenizer()

class PredictionInput(BaseModel):
    # Request body for POST /predict.
    # A single string is answered with one result object; a list of strings
    # is answered with a list of results, one per input, in input order.
    text: Union[str, List[str]]

class PredictionOutput(BaseModel):
    # One classification result returned by POST /predict.
    # Predicted source; predict() only ever produces "foxnews" or "nbc".
    label: str
    # Confidence for `label`: the model's column-1 output for "foxnews", or
    # 1 minus it for "nbc" (presumably a probability — confirm the model's
    # output activation).
    score: float

@app.get("/")
def read_root():
    """Return static service metadata plus the model's readiness.

    ``status`` is ``"ready"`` only when both the model and the tokenizer
    globals have been set, and ``"not_loaded"`` otherwise.
    """
    # Explicit None checks: don't depend on how the Keras model or tokenizer
    # objects implement truthiness.
    loaded = model is not None and tokenizer is not None
    return {
        "message": "News Source Classifier API",
        "model_type": "LSTM",
        "version": "1.0",
        "status": "ready" if loaded else "not_loaded",
    }

@app.post("/predict", response_model=Union[PredictionOutput, List[PredictionOutput]])
async def predict(input_data: PredictionInput):
    """Classify a text, or a batch of texts, as "foxnews" or "nbc".

    Lazily (re)loads the model and tokenizer if either is missing. A single
    string in ``input_data.text`` returns one result dict. A list returns a
    list of result dicts in input order, and an empty list returns ``[]``.

    Raises:
        HTTPException: 500 with detail "Model not loaded" if the lazy load
            fails, or 500 with the error message if preprocessing or
            inference fails.
    """
    # NOTE(review): loading and model.predict() are blocking calls inside an
    # ``async def``, so they stall the event loop while they run. A plain
    # ``def`` endpoint would be run in FastAPI's threadpool instead; confirm
    # the model is safe to call from multiple threads before switching.
    if model is None or tokenizer is None:
        try:
            load_model_and_tokenizer()
        except Exception as e:
            raise HTTPException(status_code=500, detail="Model not loaded") from e

    # Normalize to a batch, remembering whether to unwrap the result at the end.
    is_single = isinstance(input_data.text, str)
    texts = [input_data.text] if is_single else input_data.text
    if not texts:
        # An empty batch has nothing to classify. Skip the model entirely,
        # because Keras' predict() can reject zero-sample input.
        return []

    try:
        sequences = tokenizer.texts_to_sequences(texts)
        # maxlen must match the input length the model was trained with.
        padded = pad_sequences(sequences, maxlen=41)
        predictions = model.predict(padded, verbose=0)

        results = []
        for pred in predictions:
            # Assumes a two-column output where column 1 is the "foxnews"
            # score — TODO confirm against the training code.
            fox_score = pred[1]
            if fox_score > 0.5:
                results.append({"label": "foxnews", "score": float(fox_score)})
            else:
                results.append({"label": "nbc", "score": float(1 - fox_score)})
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return results[0] if is_single else results

@app.post("/reload")
async def reload_model():
    """Re-read the model and tokenizer from disk via load_model_and_tokenizer().

    Returns:
        A confirmation message on success.

    Raises:
        HTTPException: 500 carrying the underlying error message if loading fails.
    """
    # NOTE(review): loading is a blocking call inside an ``async def`` and
    # stalls the event loop for its duration.
    try:
        load_model_and_tokenizer()
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Model reloaded successfully"}