# mlai / app.py — commit 4c46b15 ("Fixed app") by saifeddinemk (1.57 kB).
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import LlamaTokenizer, AutoModelForCausalLM
# Load the tokenizer and model once at import time; they are shared by all requests.
# NOTE(review): this loads a 13B checkpoint with default dtype/device placement,
# which needs a lot of memory — consider torch_dtype / device_map for deployment.
MODEL_NAME = "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1"
tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
# Llama tokenizers typically define no pad token, and the /analyze endpoint
# tokenizes with padding=True, which fails without one. Reuse EOS as the pad
# token; it is dropped again by skip_special_tokens when decoding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# The model is decoder-only: generation continues from the last position, so
# batched prompts must be padded on the left, keeping each prompt's real tokens
# directly before the generated ones.
tokenizer.padding_side = "left"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Initialize the FastAPI app
app = FastAPI()
# Define a request body model for input
class LogAnalysisRequest(BaseModel):
    """JSON request body accepted by ``POST /analyze``."""

    # Log entries to analyze; the endpoint builds one prompt per entry. Typed as
    # List[str] so pydantic validates each item at the request boundary, instead
    # of the handler failing later on a non-string item.
    logs: List[str]
@app.get("/")
def hello_world():
    """Handle ``GET /`` by returning a fixed greeting string."""
    greeting = "Hello World"
    return greeting
# Define the /analyze endpoint
@app.post("/analyze")
def analyze_logs(request: LogAnalysisRequest):
    """Run the language model over each submitted log entry and return its output.

    Declared as a plain ``def`` (not ``async def``) on purpose: ``model.generate``
    is blocking compute, and FastAPI runs sync endpoints in a worker threadpool.
    A long generation therefore no longer freezes the event loop, and with it
    every other request (including ``GET /``).

    Args:
        request: Parsed request body; ``request.logs`` holds the log entries.

    Returns:
        ``{"analysis_results": [...]}`` with one decoded string per log entry, in
        the same order as ``request.logs``. For a decoder-only model ``generate``
        returns prompt + continuation, so each string begins with the prompt.

    Raises:
        HTTPException: 400 if ``logs`` is empty.
    """
    if not request.logs:
        raise HTTPException(status_code=400, detail="No logs provided.")

    # f-string rather than ``+`` so a non-string entry is formatted instead of
    # raising TypeError (which would surface as a 500).
    input_texts = [
        f"Analyze this log for malicious activity: {log}" for log in request.logs
    ]
    # Prompts of different lengths are padded into one batch. The returned
    # attention_mask marks which positions are padding.
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    inputs = inputs.to(model.device)  # keep tensors on the same device as the weights

    with torch.no_grad():
        outputs = model.generate(
            # Pass input_ids AND attention_mask; without the mask, padded rows
            # would attend to pad tokens and produce different output.
            **inputs,
            # max_new_tokens bounds only the generated text; the previous
            # max_length=100 also counted the prompt, leaving long prompts no budget.
            max_new_tokens=100,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
        )

    results = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return {"analysis_results": results}
# Serve the app with uvicorn when this file is executed as a script.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running the server directly.
    from uvicorn import run

    run(app, port=8000, host="0.0.0.0")