auditforge / cwe_api /main.py
Kaballas's picture
initialize project structure with essential configurations and components
56b6519
import ssl
import hashlib
import requests
import os
import shutil
import zipfile
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
from inferencer import inferencer
from cvss_inferencer import inferencer as cvss_inferencer
app = FastAPI()

# Cross-origin (CORS) policy for browser clients of this API.
# NOTE(review): Starlette's docs state that credentialed CORS requires explicit
# (non-"*") origins; combined with allow_credentials=True its middleware echoes
# the caller's Origin back, which effectively lets any site make credentialed
# requests. Confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],  # Any request header is accepted.
    allow_methods=["GET", "POST", "PUT", "DELETE"],  # HTTP methods accepted cross-origin.
    allow_origins=["*"],  # Accept requests from any origin.
    allow_credentials=True,
)
# Container-local locations of the CWE model artifacts.
APP_PATH = Path("/app")
LOCAL_FILE_PATH = APP_PATH / "modelo_cwe.zip"  # Downloaded model archive.
# Model directory; /update_cwe_model deletes it and extracts the archive into
# APP_PATH, presumably recreating this folder.
CWE_MODEL_FOLDER_PATH = APP_PATH / "modelo_cwe"
# Direct-download link for the latest model archive.
REMOTE_FILE_URL = "https://drive.usercontent.google.com/download?id=1OtRNObv-Il2B5nDnpzMSGj_yBJAlskuS&export=download&confirm="
# Inclusive (first, last) byte offsets compared by /check_cwe_update: the first 10 MiB.
BYTE_RANGE = (0, 10 * 1024 * 1024 - 1)
def calculate_local_checksum(file_path, byte_limit):
    """Return the SHA-256 hex digest of the first *byte_limit* bytes of a file.

    Args:
        file_path: Path of the file to hash, opened in binary mode.
        byte_limit: Maximum number of bytes read from the start of the file;
            passed straight to ``read()``, so shorter files are hashed whole.

    Returns:
        Lowercase hexadecimal SHA-256 digest string.
    """
    with open(file_path, "rb") as handle:
        prefix = handle.read(byte_limit)
    return hashlib.sha256(prefix).hexdigest()
def calculate_remote_checksum(url, byte_range, timeout=30):
    """Return the SHA-256 hex digest of bytes *byte_range* (inclusive) of *url*.

    A ``Range`` header is sent so that normally only the requested window is
    downloaded. If the server ignores it and answers with the full body
    (any status other than 206 Partial Content), the window is cut out of the
    stream locally. Either way, the digest covers exactly the same bytes as a
    local prefix hash of matching length.

    Args:
        url: Remote file URL.
        byte_range: ``(first, last)`` inclusive byte offsets to hash.
        timeout: Connect/read timeout in seconds passed to ``requests``
            (previously there was none, so a stalled server hung forever).

    Returns:
        Lowercase hexadecimal SHA-256 digest string.

    Raises:
        Exception: wrapping any ``requests`` error raised while connecting,
            on an HTTP error status, or while streaming the body.
    """
    first, last = byte_range
    headers = {'Range': f'bytes={first}-{last}'}
    sha256_hash = hashlib.sha256()
    try:
        # `with` guarantees the streamed connection is released back to the pool.
        with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # 206 means the server honoured Range; otherwise the body starts at offset 0.
            to_skip = 0 if response.status_code == 206 else first
            remaining = max(0, last - first + 1)
            for chunk in response.iter_content(chunk_size=8192):
                if to_skip:
                    skipped = min(to_skip, len(chunk))
                    chunk = chunk[skipped:]
                    to_skip -= skipped
                piece = chunk[:remaining]
                sha256_hash.update(piece)
                remaining -= len(piece)
                if remaining <= 0:
                    # Stop reading as soon as the window is complete (matters for 200 replies).
                    break
    except requests.exceptions.RequestException as e:
        raise Exception(f"Error al obtener el archivo remoto: {str(e)}") from e
    return sha256_hash.hexdigest()
class VulnerabilityRequest(BaseModel):
    """JSON request body accepted by the /classify and /cvss endpoints."""

    vuln: str  # Free-text vulnerability description handed to the inferencer.
@app.post("/classify")
async def classify_vulnerability(vuln_request: VulnerabilityRequest):
    """Run ``inferencer`` (presumably the CWE classifier) on the submitted text.

    Returns:
        ``{"result": <inferencer output>}``.
    """
    # NOTE(review): `inferencer` is a synchronous call inside an async handler,
    # so it blocks the event loop for its whole duration.
    return {"result": inferencer(vuln_request.vuln)}
# The handler function gets its own Python name so it no longer shadows the
# /classify handler `classify_vulnerability` at module level. `name=` keeps the
# route name, and therefore the generated OpenAPI operationId and summary,
# exactly as before.
@app.post("/cvss", name="classify_vulnerability")
async def classify_cvss(vuln_request: VulnerabilityRequest):
    """Run ``cvss_inferencer`` on the submitted vulnerability description.

    Returns:
        ``{"result": <cvss_inferencer output>}``.
    """
    # NOTE(review): synchronous inference inside an async handler blocks the event loop.
    vuln = vuln_request.vuln
    result = cvss_inferencer(vuln)
    return {"result": result}
@app.get("/")
async def read_root():
    """Demo endpoint: run ``inferencer`` on a fixed Spanish example description.

    Returns:
        ``{"example_vuln": <example text>, "result": <inferencer output>}``.
    """
    # The accented characters were previously mis-encoded (UTF-8 bytes read as
    # GBK: "denegaci贸n", "transici贸n", "trav茅s"), which corrupted both the
    # model input and the text echoed back to clients.
    example_vuln = (
        "Los dispositivos de CPU Siemens SIMATIC S7-300 permiten a los atacantes remotos causar una denegación de servicio "
        "(transición de modo de defecto) a través de paquetes elaborados en (1) puerto TCP 102 o (2) Profibus."
    )
    result = inferencer(example_vuln)
    return {"example_vuln": example_vuln, "result": result}
@app.get("/check_cwe_update")
async def check_cwe_update():
    """Report whether the local CWE archive matches the remote one.

    Compares the SHA-256 of the first ``BYTE_RANGE[1] + 1`` bytes of
    LOCAL_FILE_PATH with the SHA-256 of bytes ``BYTE_RANGE`` of
    REMOTE_FILE_URL. The local read always starts at offset 0, so the two only
    cover the same bytes while ``BYTE_RANGE[0] == 0``.

    Returns:
        ``{"checksum_match": bool}``.

    Raises:
        HTTPException: 500 with the error text if either checksum fails.
    """
    # NOTE(review): both helpers perform blocking file/network I/O on the event loop.
    try:
        local_digest = calculate_local_checksum(LOCAL_FILE_PATH, byte_limit=BYTE_RANGE[1] + 1)
        remote_digest = calculate_remote_checksum(REMOTE_FILE_URL, BYTE_RANGE)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"checksum_match": local_digest == remote_digest}
def _download_cwe_archive(destination):
    """Stream REMOTE_FILE_URL into the file *destination*.

    Raises:
        HTTPException: 500 if the server does not answer with status 200.
        requests.exceptions.RequestException: on connection/read failures or
            timeouts while downloading.
    """
    # Without a timeout a stalled connection would hang this request forever;
    # the tuple is (connect, read-between-bytes) seconds, not a total deadline.
    with requests.get(REMOTE_FILE_URL, stream=True, timeout=(10, 300)) as response:
        if response.status_code != 200:
            raise HTTPException(status_code=500, detail="Failed to download CWE model")
        with open(destination, "wb") as local_zip:
            for chunk in response.iter_content(chunk_size=8192):
                local_zip.write(chunk)


def _extract_cwe_archive(archive_path):
    """Replace CWE_MODEL_FOLDER_PATH with the contents of the zip at *archive_path*.

    The archive is opened first, which validates its zip structure, and only
    then is the existing model folder deleted. A non-zip download (e.g. an HTML
    error page) therefore leaves the current model untouched.

    Raises:
        zipfile.BadZipFile: if *archive_path* is not a valid zip file.
        OSError: on filesystem errors while deleting or extracting.
    """
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        if os.path.exists(CWE_MODEL_FOLDER_PATH):
            shutil.rmtree(CWE_MODEL_FOLDER_PATH)
        # NOTE(review): assumes the archive's top-level entry is the "modelo_cwe"
        # folder, so extracting into APP_PATH recreates CWE_MODEL_FOLDER_PATH — confirm.
        zip_ref.extractall(APP_PATH)


@app.post("/update_cwe_model")
async def update_cwe_model():
    """Download the CWE model archive from REMOTE_FILE_URL and re-extract it.

    The current archive is first moved to ``<LOCAL_FILE_PATH>.bak``. On any
    failure that backup is moved back. If extraction of the new archive had
    already begun, the model folder is also rebuilt from the restored archive,
    so the service is not left without model files on disk.

    Returns:
        ``{"status": "success", "message": "Updated CWE model successfully"}``.

    Raises:
        HTTPException: 500 with the underlying error text on any failure.
    """
    # NOTE(review): download and extraction run synchronously on the event loop,
    # stalling every other request until they finish. Offloading would also
    # need a guard against concurrent updates.
    # NOTE(review): nothing here reloads whatever model `inferencer` holds in
    # memory; whether new files take effect without a restart depends on it.
    backup_file_path = f"{LOCAL_FILE_PATH}.bak"
    extraction_started = False
    try:
        if os.path.exists(LOCAL_FILE_PATH):
            if os.path.exists(backup_file_path):
                os.remove(backup_file_path)
            shutil.move(LOCAL_FILE_PATH, backup_file_path)
        _download_cwe_archive(LOCAL_FILE_PATH)
        extraction_started = True
        _extract_cwe_archive(LOCAL_FILE_PATH)
        if os.path.exists(backup_file_path):
            os.remove(backup_file_path)
        return {"status": "success", "message": "Updated CWE model successfully"}
    except Exception as e:
        # Preserve an HTTPException's own status/detail instead of re-wrapping it.
        status_code = e.status_code if isinstance(e, HTTPException) else 500
        detail = e.detail if isinstance(e, HTTPException) else str(e)
        try:
            if os.path.exists(backup_file_path):
                shutil.move(backup_file_path, LOCAL_FILE_PATH)
                if extraction_started:
                    # The old model folder may already be deleted or half-overwritten.
                    _extract_cwe_archive(LOCAL_FILE_PATH)
        except Exception as restore_error:
            raise HTTPException(
                status_code=500,
                detail=f"{detail}; restoring the previous CWE model also failed: {restore_error}",
            ) from restore_error
        raise HTTPException(status_code=status_code, detail=detail) from e
if __name__ == "__main__":
    import uvicorn

    # Serve the app over HTTPS on all interfaces, port 8000. The key and
    # certificate paths are relative to the current working directory.
    tls_options = {
        "ssl_keyfile": "ssl/key.pem",
        "ssl_certfile": "ssl/cert.pem",
        "ssl_version": ssl.PROTOCOL_TLS_SERVER,
    }
    uvicorn.run("main:app", host="0.0.0.0", port=8000, **tls_options)