from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import LlamaTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
tokenizer = LlamaTokenizer.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")
model = AutoModelForCausalLM.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")

# Llama tokenizers ship without a padding token; reuse the EOS token so the
# batched padding in /analyze does not raise an error.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
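# Optional: a 13B-parameter model needs roughly 26 GB of memory in float32.
# If a CUDA GPU is available, loading in half precision roughly halves that.
# A sketch, not part of the original setup (device_map="auto" requires the
# `accelerate` package to be installed):
#   model = AutoModelForCausalLM.from_pretrained(
#       "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1",
#       torch_dtype=torch.float16,
#       device_map="auto",
#   )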
# Initialize the FastAPI app
app = FastAPI()

# Define a request body model for input: a list of raw log lines
class LogAnalysisRequest(BaseModel):
    logs: list[str]
@app.get("/")
def hello_world():
    return "Hello World"
# Define the /analyze endpoint
@app.post("/analyze")
async def analyze_logs(request: LogAnalysisRequest):
    # Check if logs are provided
    if not request.logs:
        raise HTTPException(status_code=400, detail="No logs provided.")

    # Prepend an instruction to each log line, then tokenize as one padded batch
    input_texts = ["Analyze this log for malicious activity: " + log for log in request.logs]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)

    # Generate predictions; max_new_tokens bounds only the generated text,
    # whereas max_length would cap prompt + output together and can fail
    # outright on prompts longer than the cap
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode the predictions
    results = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    # Format and return the results
    return {"analysis_results": results}
# Run the FastAPI app (if running this script directly)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
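# Example request once the server is running, assuming the port 8000 above;
# the log line is hypothetical:
#   curl -X POST http://localhost:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"logs": ["Failed password for root from 10.0.0.5 port 22 ssh2"]}'
# The response mirrors the return value above:
#   {"analysis_results": ["Analyze this log for malicious activity: ..."]}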