from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")
model = AutoModelForCausalLM.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")

# Llama-family tokenizers ship without a padding token, so calling the
# tokenizer with padding=True would raise an error; reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models should be left-padded for batched generation.
tokenizer.padding_side = "left"
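
# Note: a 13B-parameter model in full precision needs on the order of 50 GB
# of memory. A common alternative on GPU hosts (an option, not part of the
# original) is half precision with automatic device placement:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1",
#       torch_dtype=torch.float16,
#       device_map="auto",  # requires the `accelerate` package
#   )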

# Initialize the FastAPI app
app = FastAPI()

# Define a request body model for input
class LogAnalysisRequest(BaseModel):
    logs: list[str]


@app.get("/")
def hello_world():
    return "Hello World"

# Define the /analyze endpoint. Declared as a plain `def` (not `async def`)
# so FastAPI runs it in a threadpool: model.generate() blocks, and inside
# an async handler it would stall the event loop for every other request.
@app.post("/analyze")
def analyze_logs(request: LogAnalysisRequest):
    # Check if logs are provided
    if not request.logs:
        raise HTTPException(status_code=400, detail="No logs provided.")

    # Prepare the input for the model
    input_texts = ["Analyze this log for malicious activity: " + log for log in request.logs]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)

    # Generate predictions; pass the attention mask so padded positions are
    # ignored, and bound the number of *new* tokens (max_length would count
    # the prompt as well and can silently truncate output).
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            num_return_sequences=1,
        )

    # Decode the predictions
    results = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
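
    # Note: generate() returns the prompt followed by the new tokens, so each
    # decoded string echoes its prompt. A sketch of returning only the newly
    # generated analysis (valid with the left padding configured above):
    #
    #   prompt_len = inputs["input_ids"].shape[1]
    #   results = [
    #       tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
    #       for out in outputs
    #   ]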

    # Format and return the results
    response = {"analysis_results": results}
    return response

# Run the FastAPI app (if running this script directly)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
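
# A minimal smoke test once the server is running (hypothetical log line;
# adjust host/port if you change the uvicorn.run() arguments above):
#
#   curl -X POST http://localhost:8000/analyze \
#        -H "Content-Type: application/json" \
#        -d '{"logs": ["Failed password for root from 203.0.113.5 port 22"]}'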