from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import LlamaTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer and model
tokenizer = LlamaTokenizer.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")
# Llama tokenizers ship without a pad token; reuse EOS so batched padding works
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1",
    torch_dtype=torch.float16,  # half precision roughly halves memory for the 13B checkpoint
)
# Initialize the FastAPI app
app = FastAPI()
# Define a request body model for input
class LogAnalysisRequest(BaseModel):
    logs: list[str]
# Register a simple root route (without a decorator the function is never exposed)
@app.get("/")
def hello_world():
    return "Hello World"
# Define the /analyze endpoint
@app.post("/analyze")
async def analyze_logs(request: LogAnalysisRequest):
    # Check if logs are provided
    if not request.logs:
        raise HTTPException(status_code=400, detail="No logs provided.")

    # Prepare the input for the model
    input_texts = ["Analyze this log for malicious activity: " + log for log in request.logs]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    # Generate predictions
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,  # cap generated tokens rather than total length, so long prompts aren't clipped
            num_return_sequences=1,
        )
    # Decode the predictions
    results = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    # Format and return the results
    response = {"analysis_results": results}
    return response
# Run the FastAPI app (if running this script directly)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
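
Once the server is up, the /analyze endpoint can be exercised with a short client script. A minimal sketch, assuming the service is running locally on port 8000; the sample log line is purely illustrative:

import requests

# Hypothetical sample input; any list of log lines works
payload = {"logs": ["sshd[2541]: Failed password for root from 203.0.113.7 port 52344 ssh2"]}

resp = requests.post("http://localhost:8000/analyze", json=payload)
resp.raise_for_status()
print(resp.json()["analysis_results"])

Note that the decoded output includes the prompt text as well as the model's continuation, since the generated sequence is decoded in full.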