Spaces:
Running
Running
from backend.chart import mk_variations | |
from backend.config import ABS_DATASET_DOMAIN, ABS_DATASET_PATH, get_dataset_config | |
from backend.examples import audio_examples_tab, image_examples_tab, video_examples_tab | |
from flask import Flask, Response, send_from_directory, request | |
from flask_cors import CORS | |
import os | |
import logging | |
import pandas as pd | |
import json | |
from io import StringIO | |
from tools import ( | |
get_leaderboard_filters, | |
get_old_format_dataframe, | |
) # Import your function | |
import typing as tp | |
import requests | |
from urllib.parse import unquote | |
# Module logger. A plain stream handler is attached only when neither this
# logger nor any ancestor already has one (``hasHandlers`` checks the whole
# hierarchy), which avoids duplicated output when the host configures logging.
logger = logging.getLogger(__name__)

if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter("%(levelname)s:%(name)s:%(message)s")
    )
    logger.addHandler(handler)

logger.setLevel(logging.INFO)
logger.warning("Starting the Flask app...")
# Flask application serving the built frontend bundle from ../frontend/dist;
# static_url_path="" exposes those files at the site root.
app = Flask(
    __name__,
    static_folder="../frontend/dist",
    static_url_path="",
)
# Enable flask-cors on the whole app with its default settings.
CORS(app)
def index():
    """Serve the frontend entry page, ``index.html``, from the app's static folder."""
    logger.warning("Serving index.html")
    static_root = app.static_folder
    return send_from_directory(static_root, "index.html")
def data_files(dataset_name):
    """
    Serve a processed view of a dataset CSV (the files live on S3 per the
    original description; read via ``ABS_DATASET_PATH``).

    The CSV path is ``<ABS_DATASET_PATH>/<dataset_name>_<dataset_type>.csv``,
    where ``dataset_type`` is taken from the ``dataset_type`` query parameter.

    Query parameters:
        dataset_type: "benchmark" (rendered by ``get_leaderboard``) or
            "attacks_variations" (rendered by ``get_chart``).

    Returns:
        The response built by the matching renderer, or an
        ``(error message, status)`` tuple:
        400 when ``dataset_type`` is missing or unsupported,
        404 when the CSV cannot be read,
        500 when the dataset was read but processing it failed.
    """
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400

    # Validate the type before touching the filesystem: this keeps arbitrary
    # query text out of the file path, and avoids falling off the end of the
    # function (returning None, which Flask rejects) for unknown types.
    renderers = {
        "benchmark": get_leaderboard,
        "attacks_variations": get_chart,
    }
    render = renderers.get(dataset_type)
    if render is None:
        logger.error(f"Unsupported dataset_type: {dataset_type}")
        return "Unsupported dataset type", 400

    file_path = os.path.join(ABS_DATASET_PATH, dataset_name) + f"_{dataset_type}.csv"
    logger.info(f"Looking for dataset file: {file_path}")
    try:
        df = pd.read_csv(file_path)
    except Exception:
        # Missing file, unreachable location or unparsable CSV: report "not
        # found" to the client, but keep the traceback in the logs.
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    logger.info(f"Processing dataset: {dataset_name}")
    try:
        config = get_dataset_config(dataset_name)
        return render(config, df)
    except Exception:
        # The file was read fine, so this is a server-side processing error,
        # not a missing file.
        logger.exception(f"Failed to process dataset: {dataset_name}")
        return "Failed to process dataset", 500
def example_files(type):
    """
    Return the example gallery for one media type as JSON (examples are
    located via ``ABS_DATASET_PATH``).

    ``type`` must be "image", "audio" or "video". The matching
    ``*_examples_tab`` helper is called with ``ABS_DATASET_PATH``, and its
    result is JSON-encoded. Any other value yields
    ``("Invalid example type", 400)``.
    """
    # The parameter name shadows the builtin ``type``; it is kept because
    # callers (e.g. a URL rule) may pass it by keyword.
    if type not in ("image", "audio", "video"):
        return "Invalid example type", 400

    build_tab = {
        "image": image_examples_tab,
        "audio": audio_examples_tab,
        "video": video_examples_tab,
    }[type]
    payload = json.dumps(build_tab(ABS_DATASET_PATH))
    return Response(payload, mimetype="application/json")
# Proxy endpoint that relays remote dataset files to the frontend, working
# around CORS restrictions on the remote host.

# Bound (seconds) on connecting to and waiting for data from the upstream host.
# requests applies no timeout by default, so a stalled remote could otherwise
# hold a worker indefinitely.
_PROXY_TIMEOUT_SECONDS = 30


def _is_allowed_proxy_url(url, allowed_prefix):
    """Return True if ``url`` lies under ``allowed_prefix`` on the same host.

    A bare ``startswith`` check is not enough. With a prefix such as
    ``https://example.com`` it also accepts ``https://example.com.evil.net/x``
    and ``https://example.com@evil.net/x``, both of which are fetched from a
    different host. So the scheme and network location must also match exactly.
    """
    from urllib.parse import urlsplit

    if not url.startswith(allowed_prefix):
        return False
    target = urlsplit(url)
    allowed = urlsplit(allowed_prefix)
    return (target.scheme, target.netloc.lower()) == (
        allowed.scheme,
        allowed.netloc.lower(),
    )


def proxy(url):
    """
    Fetch a remote file and relay it to the frontend.

    This helps bypass CORS restrictions on remote resources. Only URLs under
    ``ABS_DATASET_DOMAIN`` (same scheme and host) are fetched.

    Args:
        url: percent-encoded target URL; it is decoded with ``unquote``.

    Returns:
        On an upstream 200, a ``Response`` with the upstream body and headers
        (minus encoding/hop-by-hop ones) plus ``Access-Control-Allow-Origin: *``.
        Otherwise an ``(error dict, status)`` tuple:
        403 for a disallowed URL,
        the upstream status for a non-200 reply,
        500 for any other failure (including timeouts).
    """
    try:
        url = unquote(url)
        # Only proxy from the trusted dataset host, for security.
        if not _is_allowed_proxy_url(url, ABS_DATASET_DOMAIN):
            return {"error": "Only proxying from allowed domains is permitted"}, 403

        # stream=True defers downloading the body until the status is known;
        # the with-block releases the connection on every exit path.
        with requests.get(url, stream=True, timeout=_PROXY_TIMEOUT_SECONDS) as response:
            if response.status_code != 200:
                return {"error": f"Failed to fetch from {url}"}, response.status_code

            # requests transparently decodes gzip/deflate bodies, so the upstream
            # encoding/length headers no longer describe ``response.content``;
            # transfer-encoding and connection are hop-by-hop headers.
            excluded_headers = {
                "content-encoding",
                "content-length",
                "transfer-encoding",
                "connection",
            }
            headers = {
                name: value
                for name, value in response.headers.items()
                if name.lower() not in excluded_headers
            }
            headers["Access-Control-Allow-Origin"] = "*"
            return Response(response.content, response.status_code, headers)
    except Exception as e:
        logger.exception(f"Proxy request failed for {url}")
        return {"error": str(e)}, 500
def get_leaderboard(config, df):
    """
    Build the leaderboard JSON response for a benchmark dataset.

    Args:
        config: dataset configuration; the keys read here are "first_cols",
            "attack_scores" and "categories".
        df: benchmark results as loaded from the CSV; must end up with a
            "model" column after ``get_old_format_dataframe``.

    Returns:
        A JSON ``Response`` whose object has three keys:
        "groups" (group name -> list of metrics), "default_selected_metrics"
        (list), and "rows" (one record per metric, with a "metric" field plus
        one field per value of the "model" column).
    """
    logger.warning(f"Processing dataset with config: {config}")

    # Add on all the derived columns the leaderboard expects.
    expanded = get_old_format_dataframe(df, config["first_cols"], config["attack_scores"])
    # Filters are computed from the expanded table before missing values are filled.
    groups, default_selection = get_leaderboard_filters(expanded, config["categories"])

    # Missing values become the literal string "NaN" (not None/null).
    # Presumably this is because json.dumps would otherwise emit a float NaN as
    # the bare token NaN, which strict JSON parsers reject.
    # The table is then transposed so each row is a metric and each column a model.
    rows = (
        expanded.fillna(value="NaN")
        .set_index("model")
        .T.reset_index()
        .rename(columns={"index": "metric"})
        .to_dict(orient="records")
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": rows,
    }
    return Response(json.dumps(payload), mimetype="application/json")
def get_chart(config, df):
    """
    Build the chart JSON response for an attack-variations dataset.

    Passes ``df`` and ``config["attacks_with_variations"]`` to ``mk_variations``
    and JSON-encodes whatever it returns.
    """
    variations = config["attacks_with_variations"]
    chart_data = mk_variations(df, variations)
    return Response(json.dumps(chart_data), mimetype="application/json")
if __name__ == "__main__":
    # The Werkzeug interactive debugger can execute arbitrary code. With
    # host="0.0.0.0" the dev server is reachable from the network, so debug
    # mode (and with it the auto-reloader) is opt-in: set FLASK_DEBUG=1 to enable.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=7860, debug=debug_mode, use_reloader=debug_mode)