Spaces:
Running
Running
File size: 5,916 Bytes
98847a8 ed37070 a1f1bf8 f762ee5 a1f1bf8 98847a8 a1f1bf8 ed37070 98847a8 f762ee5 0b598b9 f762ee5 a1f1bf8 f762ee5 0b598b9 a1f1bf8 0b598b9 f762ee5 98847a8 f762ee5 0b598b9 ed37070 0b598b9 a1f1bf8 98847a8 0b598b9 ed37070 98847a8 ed37070 98847a8 ed37070 98847a8 ed37070 98847a8 f762ee5 98847a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
from backend.chart import mk_variations
from backend.examples import audio_examples_tab, image_examples_tab, video_examples_tab
from flask import Flask, Response, send_from_directory
from flask_cors import CORS
import os
import logging
import pandas as pd
import json
from io import StringIO
from tools import (
get_leaderboard_filters,
get_old_format_dataframe,
) # Import your function
import typing as tp
import requests
from urllib.parse import unquote
logger = logging.getLogger(__name__)

if not logger.hasHandlers():
    # Attach a plain stream handler only when neither this logger nor any
    # ancestor (e.g. the root logger) already has one, to avoid duplicate lines.
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)

# Set the level unconditionally: hasHandlers() is also True when only the root
# logger is configured, and a NOTSET logger would then inherit root's level
# (WARNING by default) and silently drop this module's INFO messages.
logger.setLevel(logging.INFO)

logger.warning("Starting the Flask app...")

# static_url_path="" exposes the built frontend bundle (static_folder, resolved
# relative to this module's directory by Flask) at the site root.
app = Flask(__name__, static_folder="../frontend/dist", static_url_path="")
# Allow cross-origin requests to every route of this app.
CORS(app)
@app.route("/")
def index():
    """Return the frontend entry page, ``index.html``, from the app's static folder."""
    logger.warning("Serving index.html")
    static_dir = app.static_folder
    entry_page = "index.html"
    return send_from_directory(static_dir, entry_page)
@app.route("/data/<path:filename>")
def data_files(filename):
    """
    Serve a processed view of a CSV file from the ``data`` directory.

    ``*benchmark.csv`` files are converted by ``get_leaderboard`` and
    ``*attacks_variations.csv`` files by ``get_chart``. Any other name, a path
    that does not exist, or a path that escapes the data directory (``..``
    segments, absolute paths) yields ``("File not found", 404)``.
    """
    # Decide how to render the file before touching the disk, so unsupported
    # files are never parsed.
    if filename.endswith("benchmark.csv"):
        render = get_leaderboard
    elif filename.endswith("attacks_variations.csv"):
        render = get_chart
    else:
        return "File not found", 404

    data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data"))
    # normpath collapses ".." lexically; the prefix check then rejects anything
    # that resolved outside data_dir (a <path:> URL segment is untrusted input).
    file_path = os.path.normpath(os.path.join(data_dir, filename))
    if not file_path.startswith(data_dir + os.sep) or not os.path.isfile(file_path):
        return "File not found", 404

    df = pd.read_csv(file_path)
    logger.info("Processing file: %s", filename)
    return render(df)
@app.route("/examples/<path:type>")
def example_files(type):
    """
    Return JSON describing the example media for one modality.

    ``type`` must be "image", "audio" or "video"; the matching
    ``*_examples_tab`` builder is called with the base URL under which the
    example assets are hosted. Any other value gets a 400 response.
    """
    builders = {
        "image": image_examples_tab,
        "audio": audio_examples_tab,
        "video": video_examples_tab,
    }
    build = builders.get(type)
    if build is None:
        return "Invalid example type", 400

    payload = build("https://dl.fbaipublicfiles.com/omnisealbench/")
    return Response(json.dumps(payload), mimetype="application/json")
# Proxy endpoint that fetches remote example assets server-side, so the
# frontend is not blocked by CORS restrictions on the remote host.

# Only URLs under this prefix may be proxied; anything else would turn the
# endpoint into an open relay.
_PROXY_ALLOWED_PREFIX = "https://dl.fbaipublicfiles.com/"
# Seconds allowed for connecting and for each read from the remote host, so a
# stalled upstream cannot tie up a worker indefinitely.
_PROXY_TIMEOUT_SECONDS = 30
# Headers that must not be copied verbatim: requests has already decoded the
# body, and Flask re-frames the response itself.
_PROXY_EXCLUDED_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection"}
)


@app.route("/proxy/<path:url>")
def proxy(url):
    """
    Fetch a file from the allowed remote host and return it to the frontend.

    Returns the remote body with its headers (minus the excluded ones) plus
    ``Access-Control-Allow-Origin: *``. Returns a JSON error with status 403
    for disallowed URLs, the upstream status for non-200 upstream responses,
    and 500 for any other failure (which is also logged).
    """
    try:
        # Decode the URL parameter.
        url = unquote(url)
        if not url.startswith(_PROXY_ALLOWED_PREFIX):
            return {"error": "Only proxying from allowed domains is permitted"}, 403

        # The context manager always releases the connection back to the
        # pool, including on the early error return.
        with requests.get(url, timeout=_PROXY_TIMEOUT_SECONDS) as response:
            if response.status_code != 200:
                return {"error": f"Failed to fetch from {url}"}, response.status_code

            headers = {
                name: value
                for name, value in response.headers.items()
                if name.lower() not in _PROXY_EXCLUDED_HEADERS
            }
            headers["Access-Control-Allow-Origin"] = "*"
            return Response(response.content, response.status_code, headers)
    except Exception as e:
        # Top-level request boundary: record the traceback, then report the
        # failure to the client as JSON.
        logger.exception("Proxy request for %s failed", url)
        return {"error": str(e)}, 500
def get_leaderboard(df):
    """
    Build the leaderboard JSON response from a benchmark results DataFrame.

    The frame is expanded with ``get_old_format_dataframe`` and the filter
    groups are derived with ``get_leaderboard_filters``. The result is then
    transposed so each metric is a row. The response body is a JSON object:

    - ``groups``: group name -> list of metric names
    - ``default_selected_metrics``: list of metrics selected initially
    - ``rows``: one record per metric, with a ``"metric"`` key plus one key
      per model (taken from the ``model`` column)
    """
    # Column and category configuration forwarded to the ``tools`` helpers;
    # their exact interpretation is defined there.
    first_cols = [
        "snr",
        "sisnr",
        "stoi",
        "pesq",
    ]
    attack_scores = [
        "bit_acc",
        "log10_p_value",
        "TPR",
        "FPR",
    ]
    # Attack name -> category label used to group leaderboard filters.
    categories = {
        "speed": "Time",
        "updownresample": "Time",
        "echo": "Time",
        "random_noise": "Amplitude",
        "lowpass_filter": "Amplitude",
        "highpass_filter": "Amplitude",
        "bandpass_filter": "Amplitude",
        "smooth": "Amplitude",
        "boost_audio": "Amplitude",
        "duck_audio": "Amplitude",
        "shush": "Amplitude",
        "pink_noise": "Amplitude",
        "aac_compression": "Compression",
        "mp3_compression": "Compression",
    }

    # Expand the raw frame into the per-attack column layout.
    df = get_old_format_dataframe(df, first_cols, attack_scores)
    groups, default_selection = get_leaderboard_filters(df, categories)

    # Replace missing values with the *string* "NaN" (not None): json.dumps
    # would otherwise emit a bare NaN token, which is not valid JSON.
    # NOTE(review): presumably the frontend treats "NaN" as missing — confirm.
    df = df.fillna(value="NaN")

    # Transpose so each former column (metric) becomes a row and each model
    # becomes a column; the old column names land in a "metric" column.
    df = df.set_index("model").T.reset_index()
    df = df.rename(columns={"index": "metric"})

    result = {
        "groups": {group: list(metrics) for group, metrics in groups.items()},
        "default_selected_metrics": list(default_selection),
        "rows": df.to_dict(orient="records"),
    }
    return Response(json.dumps(result), mimetype="application/json")
def get_chart(df):
    """
    Build the attack-variations chart JSON response from a DataFrame.

    Delegates to ``mk_variations``, passing only the DataFrame, and serializes
    its return value with ``json.dumps``.
    """
    chart_data = mk_variations(df)
    return Response(json.dumps(chart_data), mimetype="application/json")
if __name__ == "__main__":
    # Werkzeug's interactive debugger runs arbitrary Python from the browser
    # (guarded only by a PIN), so it must not be on by default while listening
    # on all interfaces. Set FLASK_DEBUG=1 to opt in for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled, use_reloader=debug_enabled)
|