|
import torch |
|
from transformers import AutoModelForSequenceClassification |
|
|
|
|
|
# Identifier handed to ``from_pretrained`` below.
# NOTE(review): the value "model.pt" looks like a torch checkpoint file, but
# ``from_pretrained`` expects a Hugging Face model id (e.g. "bert-base-uncased")
# or a directory containing ``config.json`` plus weights — confirm and update.
MODEL_NAME = "model.pt"


# Load the sequence-classification model once, at import time.
# NOTE(review): this blocks module import and raises immediately if the value
# above is invalid; consider lazy loading if startup time matters.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
|
|
|
|
def predict(text):
    """
    Classify *text* with the module-level sequence-classification model.

    The input is tokenized, run through the model with gradient tracking
    disabled, and the index of the highest-scoring class is returned.

    Parameters
    ----------
    text : str
        Raw input string to classify.

    Returns
    -------
    int
        Index of the predicted class (argmax over the logits).
    """
    # BUG FIX: the original assigned the raw string to ``inputs`` and then
    # called ``model(**inputs)`` — unpacking a ``str`` as keyword arguments
    # raises TypeError. The text must first be tokenized into a mapping of
    # tensors, as the original docstring itself described.
    inputs = _get_tokenizer()(text, return_tensors="pt", truncation=True)

    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # argmax over the class dimension; .item() converts the single-element
    # tensor to a plain Python int, which serializes cleanly to JSON.
    predictions = torch.argmax(outputs.logits, dim=-1)
    return predictions.item()


def _get_tokenizer():
    """Lazily build and cache the tokenizer paired with ``MODEL_NAME``."""
    # Cached on the function object so the (slow) tokenizer load happens at
    # most once per process, without adding a new module-level import.
    if not hasattr(_get_tokenizer, "_cache"):
        from transformers import AutoTokenizer  # transformers is already a file dependency

        _get_tokenizer._cache = AutoTokenizer.from_pretrained(MODEL_NAME)
    return _get_tokenizer._cache
|
|
|
|
|
def handle_request(data):
    """
    Extract the text payload from a request and run the classifier on it.

    Parameters
    ----------
    data : dict
        Request payload; must contain a ``"text"`` key holding the input
        string (a missing key raises ``KeyError``, as before).

    Returns
    -------
    dict
        ``{"prediction": <predicted class index>}``.
    """
    return {"prediction": predict(data["text"])}
|
|
|
if __name__ == "__main__":
    # NOTE(review): the FastAPI app and its route are only defined when this
    # file is executed directly, yet no ASGI server is ever started here — so
    # running the script builds the app and then exits, doing nothing. An
    # external server ("uvicorn thismodule:app") cannot see ``app`` either,
    # because this branch does not run on import. Consider moving the app
    # definition to module level and launching uvicorn here instead.
    from fastapi import FastAPI

    app = FastAPI()

    @app.post("/predict")
    async def predict_from_text(data: dict):
        """Handle POST /predict: expects a JSON body with a "text" field."""
        response = handle_request(data)
        return response