Spaces:
Running
Running
File size: 5,637 Bytes
98847a8 b087e88 ed37070 b087e88 a1f1bf8 f762ee5 a1f1bf8 98847a8 a1f1bf8 ed37070 98847a8 f762ee5 0b598b9 f762ee5 a1f1bf8 f762ee5 0b598b9 a1f1bf8 0b598b9 f762ee5 98847a8 f762ee5 54be5f9 0b598b9 b087e88 0b598b9 b087e88 54be5f9 b087e88 a1f1bf8 54be5f9 b087e88 0b598b9 ed37070 98847a8 b087e88 98847a8 ed37070 b087e88 ed37070 b087e88 ed37070 b087e88 ed37070 98847a8 ed37070 b087e88 ed37070 98847a8 b087e88 98847a8 b087e88 98847a8 54be5f9 98847a8 54be5f9 98847a8 b087e88 98847a8 b087e88 98847a8 f762ee5 98847a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
from backend.chart import mk_variations
from backend.config import ABS_DATASET_DOMAIN, ABS_DATASET_PATH, get_dataset_config
from backend.examples import audio_examples_tab, image_examples_tab, video_examples_tab
from flask import Flask, Response, send_from_directory, request
from flask_cors import CORS
import os
import logging
import pandas as pd
import json
from io import StringIO
from tools import (
get_leaderboard_filters,
get_old_format_dataframe,
) # Import your function
import typing as tp
import requests
from urllib.parse import unquote
# Module-level logger. A stream handler is attached only when neither this
# logger nor any of its ancestors already has a handler configured.
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
# Emitted once at import time.
logger.warning("Starting the Flask app...")
# static_url_path="" mounts the static folder at the URL root.
# NOTE(review): "../frontend/dist" is presumably the built frontend bundle - confirm.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
# Enable CORS handling for the whole app with flask_cors defaults.
CORS(app)
@app.route("/")
def index():
    """Serve ``index.html`` from the application's static folder."""
    logger.warning("Serving index.html")
    static_root = app.static_folder
    return send_from_directory(static_root, "index.html")
@app.route("/data/<path:dataset_name>")
def data_files(dataset_name):
    """
    Load a dataset CSV and return it as a JSON payload.

    The CSV location is ``<ABS_DATASET_PATH>/<dataset_name>_<dataset_type>.csv``,
    where ``dataset_type`` comes from the ``dataset_type`` query parameter
    (the original docstring states these files are served from S3).

    Args:
        dataset_name: Path segment captured from the URL; also used to look up
            the dataset configuration via ``get_dataset_config``.

    Returns:
        The JSON ``Response`` built by ``get_leaderboard`` (``benchmark``) or
        ``get_chart`` (``attacks_variations``), or a ``(message, status)``
        tuple:

        * 400 - ``dataset_type`` missing or not a supported type.
        * 404 - the CSV could not be loaded.
        * 500 - the CSV was loaded but could not be processed.
    """
    # Response builder for each supported dataset type.
    builders = {
        "benchmark": get_leaderboard,
        "attacks_variations": get_chart,
    }

    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400

    build_response = builders.get(dataset_type)
    if build_response is None:
        # Reject unknown types up front: falling through would make the view
        # return None, which Flask does not accept as a response.
        logger.error("Unsupported dataset_type: %r", dataset_type)
        return "Unsupported dataset type", 400

    file_path = os.path.join(ABS_DATASET_PATH, dataset_name) + f"_{dataset_type}.csv"
    logger.info(f"Looking for dataset file: {file_path}")

    # Loading and processing are handled separately so a processing bug is not
    # misreported as a missing file. ``Exception`` (rather than a bare
    # ``except``) lets KeyboardInterrupt/SystemExit propagate.
    try:
        df = pd.read_csv(file_path)
    except Exception:
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    try:
        logger.info(f"Processing dataset: {dataset_name}")
        config = get_dataset_config(dataset_name)
        return build_response(config, df)
    except Exception:
        logger.exception(f"Failed to process dataset: {dataset_name}")
        return "Failed to process dataset", 500
@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return the examples payload for one media type as JSON.

    ``type`` is bound by the route and selects the builder: ``"image"``,
    ``"audio"`` or ``"video"``. Each builder is called with
    ``ABS_DATASET_PATH``. Any other value yields a plain-text 400 response.
    """
    # Media type -> function that builds the examples payload.
    tab_builders = {
        "image": image_examples_tab,
        "audio": audio_examples_tab,
        "video": video_examples_tab,
    }
    build_tab = tab_builders.get(type)
    if build_tab is None:
        return "Invalid example type", 400
    payload = json.dumps(build_tab(ABS_DATASET_PATH))
    return Response(payload, mimetype="application/json")
# Add a proxy endpoint to bypass CORS issues
@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Fetch a remote file on behalf of the frontend and relay it.

    This lets the frontend load remote resources that would otherwise be
    blocked by CORS restrictions.

    Args:
        url: Percent-encoded target URL captured from the request path.

    Returns:
        A ``Response`` with the upstream body and headers (minus hop-by-hop and
        length/encoding headers) plus ``Access-Control-Allow-Origin: *``, or a
        ``(json_error, status)`` tuple:

        * 403 - the target is not under ``ABS_DATASET_DOMAIN``.
        * upstream status - the upstream answered with anything but 200.
        * 500 - any other failure (including an upstream timeout).
    """
    # Local import keeps this block self-contained; the file's top-level
    # import only brings in ``unquote``.
    from urllib.parse import urlparse

    # Seconds to wait for the upstream before giving up; without a timeout a
    # stalled upstream would block this worker indefinitely.
    upstream_timeout = 30

    # Headers that must not be relayed verbatim: the body is re-sent by Flask
    # (so length/encoding no longer apply) and the CORS header is set below.
    excluded_headers = {
        "content-encoding",
        "content-length",
        "transfer-encoding",
        "connection",
        "access-control-allow-origin",
    }

    try:
        # Decode the URL parameter
        url = unquote(url)

        # A plain prefix check can be bypassed with hosts such as
        # "<allowed>.evil.com" or "<allowed>@evil.com", so the scheme and
        # network location must also match the allowed origin exactly.
        target = urlparse(url)
        allowed = urlparse(ABS_DATASET_DOMAIN)
        same_origin = (target.scheme, target.netloc) == (allowed.scheme, allowed.netloc)
        if not (url.startswith(ABS_DATASET_DOMAIN) and same_origin):
            return {"error": "Only proxying from allowed domains is permitted"}, 403

        # The context manager releases the streamed connection on every path,
        # including the early return for non-200 answers.
        with requests.get(url, stream=True, timeout=upstream_timeout) as upstream:
            status = upstream.status_code
            if status != 200:
                return {"error": f"Failed to fetch from {url}"}, status
            headers = {
                name: value
                for name, value in upstream.headers.items()
                if name.lower() not in excluded_headers
            }
            body = upstream.content

        # Add CORS headers
        headers["Access-Control-Allow-Origin"] = "*"
        return Response(body, status, headers)
    except Exception as e:
        logger.exception("Proxy request failed for %s", url)
        return {"error": str(e)}, 500
def get_leaderboard(config, df):
    """
    Build the leaderboard JSON response from a benchmark DataFrame.

    Reads ``first_cols``, ``attack_scores`` and ``categories`` from ``config``.
    The response body has three keys: ``groups`` (group name -> list of
    metrics), ``default_selected_metrics`` and ``rows`` (one record per
    metric, with one field per model).
    """
    logger.warning(f"Processing dataset with config: {config}")

    # Per the original comment this step adds on all the columns
    # (helper defined in ``tools``; not visible here).
    expanded = get_old_format_dataframe(
        df, config["first_cols"], config["attack_scores"]
    )
    groups, default_selection = get_leaderboard_filters(expanded, config["categories"])

    # Missing values become the literal string "NaN" so the payload is valid
    # JSON; the frame is then transposed so each metric is a row and each
    # model is a column.
    by_metric = (
        expanded.fillna(value="NaN")
        .set_index("model")
        .T.reset_index()
        .rename(columns={"index": "metric"})
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": by_metric.to_dict(orient="records"),
    }
    return Response(json.dumps(payload), mimetype="application/json")
def get_chart(config, df):
    """
    Build the attack-variations chart JSON response.

    Passes ``df`` and ``config["attacks_with_variations"]`` to
    ``mk_variations`` and serialises whatever it returns as JSON.
    """
    attacks = config["attacks_with_variations"]
    chart_data = mk_variations(df, attacks)
    body = json.dumps(chart_data)
    return Response(body, mimetype="application/json")
if __name__ == "__main__":
    # The Werkzeug debugger lets anyone who can reach the server execute
    # arbitrary code, so it must not be enabled by default while binding to
    # all interfaces. Opt in for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled, use_reloader=debug_enabled)
|