File size: 5,916 Bytes
98847a8
ed37070
a1f1bf8
 
f762ee5
 
a1f1bf8
 
 
98847a8
 
 
 
a1f1bf8
ed37070
 
98847a8
f762ee5
 
0b598b9
 
 
 
f762ee5
 
a1f1bf8
f762ee5
0b598b9
a1f1bf8
0b598b9
f762ee5
 
 
98847a8
f762ee5
 
 
0b598b9
 
 
ed37070
0b598b9
 
 
 
a1f1bf8
98847a8
 
 
 
 
0b598b9
 
 
 
ed37070
 
98847a8
 
 
ed37070
98847a8
 
ed37070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98847a8
ed37070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98847a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f762ee5
98847a8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from backend.chart import mk_variations
from backend.examples import audio_examples_tab, image_examples_tab, video_examples_tab
from flask import Flask, Response, send_from_directory
from flask_cors import CORS
import os
import logging
import pandas as pd
import json
from io import StringIO
from tools import (
    get_leaderboard_filters,
    get_old_format_dataframe,
)  # Import your function
import typing as tp
import requests
from urllib.parse import unquote


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Attach a stream handler only when neither this logger nor any of its
# ancestors already has one, so an existing logging setup is not duplicated.
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)

logger.setLevel(logging.INFO)
# Emitted at WARNING level, which is above the INFO threshold set just above.
logger.warning("Starting the Flask app...")

# The Flask app serves static files from ../frontend/dist at the URL root
# (static_url_path is empty).
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
# Enable flask_cors on the whole app with its default settings.
CORS(app)


@app.route("/")
def index():
    """
    Serve index.html from the app's static folder.
    """
    logger.warning("Serving index.html")
    static_root = app.static_folder
    return send_from_directory(static_root, "index.html")


@app.route("/data/<path:filename>")
def data_files(filename):
    """
    Serve a CSV file from the data directory as a JSON response.

    Files whose name ends with "benchmark.csv" are passed to
    ``get_leaderboard``; files ending with "attacks_variations.csv" are
    passed to ``get_chart``. Any other name, a missing file, or a path that
    resolves outside the data directory yields a 404.

    Args:
        filename: Path of the requested file, relative to the data directory.

    Returns:
        The response built by the matching handler, or the tuple
        ``("File not found", 404)``.
    """
    not_found = ("File not found", 404)

    data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
    file_path = os.path.abspath(os.path.join(data_dir, filename))

    # `filename` comes straight from the URL: refuse anything that escapes
    # the data directory (e.g. through ".." segments or an absolute path).
    if not file_path.startswith(data_dir + os.sep):
        return not_found

    # Pick the handler before touching the disk so unsupported files are
    # never parsed.
    if filename.endswith("benchmark.csv"):
        build_response = get_leaderboard
    elif filename.endswith("attacks_variations.csv"):
        build_response = get_chart
    else:
        return not_found

    if not os.path.isfile(file_path):
        return not_found

    df = pd.read_csv(file_path)
    logger.info("Processing file: %s", filename)
    return build_response(df)


@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return the examples for one media type as a JSON response.

    ``type`` selects the builder: "image", "audio" or "video". Any other
    value yields ``("Invalid example type", 400)``.
    """
    # NOTE: `type` shadows the builtin, but it has to match the route
    # variable name in the decorator above.
    base_url = "https://dl.fbaipublicfiles.com/omnisealbench/"

    # One builder per supported example type.
    tab_builders = {
        "image": image_examples_tab,
        "audio": audio_examples_tab,
        "video": video_examples_tab,
    }

    build_tab = tab_builders.get(type)
    if build_tab is None:
        return "Invalid example type", 400

    payload = build_tab(base_url)
    return Response(json.dumps(payload), mimetype="application/json")


# Add a proxy endpoint to bypass CORS issues
@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Proxy endpoint to fetch remote files and serve them to the frontend.
    This helps bypass CORS restrictions on remote resources.

    Args:
        url: The (possibly percent-encoded) remote URL to fetch. Only URLs
            under https://dl.fbaipublicfiles.com/ are accepted.

    Returns:
        A ``Response`` carrying the upstream body and headers on success;
        otherwise a JSON error dict with status 403 (domain not allowed),
        the upstream status code (non-200 upstream reply), or 500 (any
        failure while fetching).
    """
    # Upstream headers that must not be copied verbatim: the first four
    # describe the upstream transfer, and the CORS header is set explicitly
    # below so it must not appear twice under a different casing.
    excluded_headers = {
        "content-encoding",
        "content-length",
        "transfer-encoding",
        "connection",
        "access-control-allow-origin",
    }
    # (connect, read) timeouts in seconds, so a stalled upstream cannot
    # block this worker forever.
    timeout = (5, 30)

    try:
        # Decode the URL parameter
        url = unquote(url)

        # Make sure we're only proxying from trusted domains for security
        if not url.startswith("https://dl.fbaipublicfiles.com/"):
            return {"error": "Only proxying from allowed domains is permitted"}, 403

        # The context manager releases the streamed connection on every
        # path, including the early return for non-200 replies.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                return {"error": f"Failed to fetch from {url}"}, response.status_code

            headers = {
                name: value
                for name, value in response.headers.items()
                if name.lower() not in excluded_headers
            }

            # Add CORS headers
            headers["Access-Control-Allow-Origin"] = "*"

            return Response(response.content, response.status_code, headers)
    except Exception as e:
        # Top-level boundary for this endpoint: log the traceback and keep
        # returning a JSON error instead of a bare 500 page.
        logger.exception("Proxy request failed for %s", url)
        return {"error": str(e)}, 500


def get_leaderboard(df):
    """
    Build the leaderboard JSON response from a benchmark DataFrame.

    The frame is first reshaped by ``get_old_format_dataframe``; the metric
    groups and the default selection come from ``get_leaderboard_filters``.
    Missing values are replaced by the string "NaN", then the table is
    transposed so that each record describes one metric, with the former
    "model" values as keys.

    Returns:
        A JSON ``Response`` with the keys "groups",
        "default_selected_metrics" and "rows".
    """
    # Column lists handed to get_old_format_dataframe.
    leading_columns = ["snr", "sisnr", "stoi", "pesq"]
    score_names = ["bit_acc", "log10_p_value", "TPR", "FPR"]

    # Category lookup handed to get_leaderboard_filters
    # (presumably attack name -> category; verify against tools).
    category_map = {
        "speed": "Time",
        "updownresample": "Time",
        "echo": "Time",
        "random_noise": "Amplitude",
        "lowpass_filter": "Amplitude",
        "highpass_filter": "Amplitude",
        "bandpass_filter": "Amplitude",
        "smooth": "Amplitude",
        "boost_audio": "Amplitude",
        "duck_audio": "Amplitude",
        "shush": "Amplitude",
        "pink_noise": "Amplitude",
        "aac_compression": "Compression",
        "mp3_compression": "Compression",
    }

    expanded = get_old_format_dataframe(df, leading_columns, score_names)
    groups, default_selection = get_leaderboard_filters(expanded, category_map)

    # Missing values are sent as the literal string "NaN".
    filled = expanded.fillna(value="NaN")

    # Transpose so every former column becomes a row labelled "metric" and
    # every model becomes a column.
    by_metric = (
        filled.set_index("model")
        .T.reset_index()
        .rename(columns={"index": "metric"})
    )

    payload = {
        "groups": {name: list(members) for name, members in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": by_metric.to_dict(orient="records"),
    }
    return Response(json.dumps(payload), mimetype="application/json")


def get_chart(df):
    """
    Return the result of ``mk_variations(df)`` serialized as a JSON response.
    """
    variations = mk_variations(df)
    body = json.dumps(variations)
    return Response(body, mimetype="application/json")


if __name__ == "__main__":
    # The server listens on all interfaces, and Flask's debug mode enables
    # the Werkzeug interactive debugger, which can execute arbitrary code.
    # Debug mode (and the reloader with it) is therefore opt-in:
    # set FLASK_DEBUG=1 for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(
        host="0.0.0.0",
        port=7860,
        debug=debug_enabled,
        use_reloader=debug_enabled,
    )