File size: 5,637 Bytes
98847a8
b087e88
ed37070
b087e88
a1f1bf8
f762ee5
 
a1f1bf8
 
 
98847a8
 
 
 
a1f1bf8
ed37070
 
98847a8
f762ee5
 
0b598b9
 
 
 
f762ee5
 
a1f1bf8
f762ee5
0b598b9
a1f1bf8
0b598b9
f762ee5
 
 
98847a8
f762ee5
 
 
54be5f9
 
0b598b9
b087e88
0b598b9
b087e88
 
 
 
 
 
 
 
54be5f9
b087e88
a1f1bf8
54be5f9
b087e88
 
 
 
 
 
 
 
0b598b9
 
ed37070
 
98847a8
b087e88
98847a8
ed37070
 
 
b087e88
ed37070
 
 
b087e88
ed37070
 
 
b087e88
ed37070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98847a8
ed37070
b087e88
ed37070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98847a8
 
b087e88
98847a8
b087e88
98847a8
 
54be5f9
98847a8
54be5f9
98847a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b087e88
98847a8
 
 
 
b087e88
98847a8
 
 
 
 
 
 
f762ee5
98847a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from backend.chart import mk_variations
from backend.config import ABS_DATASET_DOMAIN, ABS_DATASET_PATH, get_dataset_config
from backend.examples import audio_examples_tab, image_examples_tab, video_examples_tab
from flask import Flask, Response, send_from_directory, request
from flask_cors import CORS
import os
import logging
import pandas as pd
import json
from io import StringIO
from tools import (
    get_leaderboard_filters,
    get_old_format_dataframe,
)  # Import your function
import typing as tp
import requests
from urllib.parse import unquote


# Module-level logger. A stream handler is attached only when neither this
# logger nor any of its ancestors already has one, so messages are not
# emitted twice when logging has been configured elsewhere.
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)

# INFO and above are emitted by this logger.
logger.setLevel(logging.INFO)
logger.warning("Starting the Flask app...")

# Flask application. Static files are served from ../frontend/dist at the URL
# root (static_url_path=""), which is also where index() reads index.html.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
# Enable cross-origin requests on the app using the flask_cors defaults.
CORS(app)


@app.route("/")
def index():
    """Serve ``index.html`` from the application's static folder."""
    entry_page = "index.html"
    logger.warning("Serving index.html")
    return send_from_directory(app.static_folder, entry_page)


@app.route("/data/<path:dataset_name>")
def data_files(dataset_name):
    """
    Load a dataset CSV and return it rendered as JSON.

    The CSV location is ``ABS_DATASET_PATH`` joined with ``dataset_name``,
    followed by ``_<dataset_type>.csv``, where ``dataset_type`` is read from
    the ``dataset_type`` query parameter.

    Args:
        dataset_name: Dataset identifier taken from the URL path.

    Returns:
        The response built by ``get_leaderboard`` (``dataset_type=benchmark``)
        or ``get_chart`` (``dataset_type=attacks_variations``) on success;
        otherwise a ``(message, status)`` tuple:

        * 400 - ``dataset_type`` is missing or not one of the supported values.
        * 404 - the CSV could not be read.
        * 500 - the CSV was read but building the response failed.
    """
    # Supported dataset types and the function that renders each of them.
    # Validating against this table up front guarantees that every request
    # gets an explicit response (an unknown type used to fall through and
    # return None) and keeps arbitrary query text out of the file path.
    renderers = {
        "benchmark": get_leaderboard,
        "attacks_variations": get_chart,
    }

    # Get dataset_type from query params
    dataset_type = request.args.get("dataset_type")
    if not dataset_type:
        logger.error("No dataset_type provided in query parameters.")
        return "Dataset type not specified", 400
    if dataset_type not in renderers:
        logger.error(f"Unsupported dataset_type: {dataset_type}")
        return "Invalid dataset type", 400

    file_path = os.path.join(ABS_DATASET_PATH, dataset_name) + f"_{dataset_type}.csv"
    logger.info(f"Looking for dataset file: {file_path}")

    # Reading and processing are guarded separately so that a bug in the
    # processing code is not misreported to the client as a missing file.
    try:
        df = pd.read_csv(file_path)
    except Exception:
        # logger.exception records the traceback, which the bare except lost.
        logger.exception(f"Failed to fetch file: {file_path}")
        return "File not found", 404

    try:
        logger.info(f"Processing dataset: {dataset_name}")
        config = get_dataset_config(dataset_name)
        return renderers[dataset_type](config, df)
    except Exception:
        logger.exception(f"Failed to process dataset: {dataset_name}")
        return "Failed to process dataset", 500


@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return the examples payload for the requested media type as JSON.

    ``type`` selects which examples builder is called with
    ``ABS_DATASET_PATH``: ``"image"``, ``"audio"`` or ``"video"``. Any other
    value yields a 400 response.
    """
    # One builder per supported media type; all share the same call signature.
    builders = {
        "image": image_examples_tab,
        "audio": audio_examples_tab,
        "video": video_examples_tab,
    }

    build_examples = builders.get(type)
    if build_examples is None:
        return "Invalid example type", 400

    payload = json.dumps(build_examples(ABS_DATASET_PATH))
    return Response(payload, mimetype="application/json")


# Add a proxy endpoint to bypass CORS issues
@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Fetch a remote file on behalf of the frontend and relay it.

    This lets the browser load remote resources that it could not request
    directly because of CORS restrictions.

    Args:
        url: Target URL taken from the request path (percent-decoded here).

    Returns:
        A ``Response`` carrying the upstream body and headers (minus
        hop-by-hop/encoding headers, plus ``Access-Control-Allow-Origin: *``)
        when the upstream answers 200. Otherwise a JSON error body with:

        * 403 - the URL is not on the allowed origin.
        * the upstream status code - the upstream answered with a non-200.
        * 500 - any exception raised while handling the request.
    """
    # Imported locally so this function is self-contained; the module-level
    # imports only bring in ``unquote``.
    from urllib.parse import urlparse

    try:
        # Decode the URL parameter
        url = unquote(url)

        # Make sure we're only proxying from trusted domains for security.
        # A plain prefix test is not sufficient: "https://host.evil.com/x" and
        # "https://host@evil.com/x" both start with "https://host". Require
        # the scheme and network location to match the allowed origin exactly,
        # in addition to the original prefix check.
        allowed = urlparse(ABS_DATASET_DOMAIN)
        target = urlparse(url)
        same_origin = (target.scheme, target.netloc) == (
            allowed.scheme,
            allowed.netloc,
        )
        if not (url.startswith(ABS_DATASET_DOMAIN) and same_origin):
            return {"error": "Only proxying from allowed domains is permitted"}, 403

        # The whole body is relayed at once, so streaming is not needed; a
        # non-streamed request also releases the connection on every return
        # path. The timeout bounds connect/read stalls so a hung upstream
        # cannot block this worker forever.
        response = requests.get(url, timeout=30)

        if response.status_code != 200:
            return {"error": f"Failed to fetch from {url}"}, response.status_code

        # Drop headers that describe the upstream transfer rather than the
        # body we send (requests has already decoded the content).
        excluded_headers = {
            "content-encoding",
            "content-length",
            "transfer-encoding",
            "connection",
        }
        headers = {
            name: value
            for name, value in response.headers.items()
            if name.lower() not in excluded_headers
        }

        # Add CORS headers
        headers["Access-Control-Allow-Origin"] = "*"

        return Response(response.content, response.status_code, headers)
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        logger.exception(f"Proxy request failed for {url}")
        return {"error": str(e)}, 500


def get_leaderboard(config, df):
    """
    Build the leaderboard JSON response for a benchmark DataFrame.

    Args:
        config: Dataset configuration; the keys ``first_cols``,
            ``attack_scores`` and ``categories`` are read here.
        df: Benchmark data as loaded from the dataset CSV.

    Returns:
        A JSON ``Response`` with three keys: ``groups`` (group name -> list of
        metrics), ``default_selected_metrics`` and ``rows`` (one record per
        metric, with one field per model).
    """
    logger.warning(f"Processing dataset with config: {config}")

    # Expand the raw frame into the older column layout expected downstream.
    expanded = get_old_format_dataframe(
        df, config["first_cols"], config["attack_scores"]
    )
    groups, default_selection = get_leaderboard_filters(expanded, config["categories"])

    # Missing values are replaced with the literal string "NaN" before
    # serialization.
    filled = expanded.fillna(value="NaN")

    # Pivot so every original column becomes a row keyed by "metric" and every
    # model becomes a column.
    by_metric = (
        filled.set_index("model").T.reset_index().rename(columns={"index": "metric"})
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": by_metric.to_dict(orient="records"),
    }
    return Response(json.dumps(payload), mimetype="application/json")


def get_chart(config, df):
    """
    Build the attack-variations chart JSON response.

    Passes ``df`` and ``config["attacks_with_variations"]`` to
    ``mk_variations`` and returns its result serialized as JSON.
    """
    variations = config["attacks_with_variations"]
    chart_payload = mk_variations(df, variations)
    return Response(json.dumps(chart_payload), mimetype="application/json")


if __name__ == "__main__":
    # The server binds to all interfaces, and Werkzeug's interactive debugger
    # lets anyone who can reach it execute arbitrary Python. Debug mode (and
    # the reloader that goes with it) is therefore opt-in: set FLASK_DEBUG=1
    # for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(
        host="0.0.0.0",
        port=7860,
        debug=debug_enabled,
        use_reloader=debug_enabled,
    )