Spaces:
Running
Running
import { useEffect, useState } from 'react' | |
import { | |
LineChart, | |
Line, | |
XAxis, | |
YAxis, | |
CartesianGrid, | |
Tooltip, | |
Legend, | |
ResponsiveContainer, | |
} from 'recharts' | |
import API from '../API' | |
/** Props accepted by the DataChart component. */
type DataChartProps = {
  /** Dataset name; DataChart loads `data/<dataset>_attacks_variations` for it. */
  dataset: string
}
/**
 * One record of the `all_attacks_df` array in the attacks-variations JSON.
 *
 * Apart from `metric`, columns are accessed by arbitrary name: DataChart reads
 * `strength`, `attack`, `model`, and the currently selected metric name.
 */
type Row = {
  metric: string
  [column: string]: string | number
}
/** Props for the metric dropdown. */
type MetricSelectorProps = {
  /** Metric names offered as options, in the set's iteration order. */
  metrics: Set<string>
  /** Current selection; `null` is given to the `<select>` as an empty string. */
  selectedMetric: string | null
  /** Passed straight through as the `<select>`'s `onChange` handler. */
  onMetricChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
}

/** Controlled `<select>` listing the available metrics inside a labelled fieldset. */
const MetricSelector = ({ metrics, selectedMetric, onMetricChange }: MetricSelectorProps) => {
  const options = Array.from(metrics, (name) => (
    <option key={name} value={name}>
      {name}
    </option>
  ))
  const currentValue = selectedMetric ?? ''

  return (
    <fieldset className="fieldset mb-4">
      <legend className="fieldset-legend">Metric</legend>
      <select
        id="metric-selector"
        className="select select-bordered w-full"
        value={currentValue}
        onChange={onMetricChange}
      >
        {options}
      </select>
    </fieldset>
  )
}
/**
 * Controlled `<select>` listing the available attacks inside a labelled fieldset.
 *
 * @param props.attacks - attack names, rendered as options in the set's iteration order
 * @param props.selectedAttack - current selection; `null` is given to the `<select>` as `''`
 * @param props.onAttackChange - handed directly to the `<select>`'s `onChange`
 */
const AttackSelector = (props: {
  attacks: Set<string>
  selectedAttack: string | null
  onAttackChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
}) => {
  const { attacks, selectedAttack, onAttackChange } = props
  const attackOptions = Array.from(attacks).map((name) => (
    <option key={name} value={name}>
      {name}
    </option>
  ))

  return (
    <fieldset className="fieldset mb-4">
      <legend className="fieldset-legend">Attack</legend>
      <select
        id="attack-selector"
        onChange={onAttackChange}
        value={selectedAttack ?? ''}
        className="select select-bordered w-full"
      >
        {attackOptions}
      </select>
    </fieldset>
  )
}
/** Stroke colours for the per-model lines, cycled when there are more models than colours. */
const MODEL_COLORS = ['#8884d8', '#82ca9d', '#ffc658', '#ff8042', '#0088fe', '#00C49F']

/**
 * Formats an axis tick or tooltip value with three decimals.
 *
 * Non-number values (e.g. numeric strings from the JSON) are coerced first.
 * Anything still non-finite is rendered via `String` instead of throwing, as a
 * bare `value.toFixed(3)` would for non-numbers.
 */
const formatValue = (value: unknown): string => {
  const num = typeof value === 'number' ? value : Number(value)
  return Number.isFinite(num) ? num.toFixed(3) : String(value)
}

/**
 * Returns `[min, max]` over the finite `strength` values of `rows`, or
 * `undefined` when there are none.
 *
 * This loops instead of spreading into `Math.min`/`Math.max`. The spread form
 * yields `[Infinity, -Infinity]` for an empty array and can overflow the call
 * stack on very large ones.
 */
const getStrengthDomain = (rows: Row[]): [number, number] | undefined => {
  let min = Infinity
  let max = -Infinity
  for (const row of rows) {
    const strength = Number(row.strength)
    if (!Number.isFinite(strength)) continue
    min = Math.min(min, strength)
    max = Math.max(max, strength)
  }
  return min <= max ? [min, max] : undefined
}

/**
 * Line chart of one metric against attack strength, with one line per model.
 *
 * It loads `data/<dataset>_attacks_variations` via `API.fetchStaticFile` and lets
 * the user pick a metric and an attack. It then plots the selected attack's
 * rows, sorted by `strength`.
 *
 * @param dataset - dataset name used in the static file path; the data is
 *   re-fetched whenever it changes.
 */
const DataChart = ({ dataset }: DataChartProps) => {
  const [chartData, setChartData] = useState<Row[]>([])
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<string | null>(null)
  const [metrics, setMetrics] = useState<Set<string>>(new Set())
  const [attacks, setAttacks] = useState<Set<string>>(new Set())
  const [selectedMetric, setSelectedMetric] = useState<string | null>(null)
  const [selectedAttack, setSelectedAttack] = useState<string | null>(null)

  useEffect(() => {
    // Drop responses that arrive after unmount or after `dataset` changed again,
    // so a slow earlier request cannot overwrite state for the newer dataset.
    let cancelled = false
    setLoading(true)
    setError(null)

    API.fetchStaticFile(`data/${dataset}_attacks_variations`)
      .then((response) => {
        if (cancelled) return
        const data = JSON.parse(response)
        const rows: Row[] = data['all_attacks_df'].map((row: Row) => {
          const newRow: Row = { ...row }
          // strength may be stored as a string; the numeric x-axis needs a number.
          if (typeof newRow.strength === 'string') {
            newRow.strength = parseFloat(newRow.strength)
          }
          return newRow
        })
        const metricList: string[] = data['metrics']
        const attackList: string[] = data['attacks_with_variations']
        setSelectedMetric(metricList[0] ?? null)
        setMetrics(new Set(metricList))
        setSelectedAttack(attackList[0] ?? null)
        setAttacks(new Set(attackList))
        setChartData(rows)
        setLoading(false)
      })
      .catch((err: unknown) => {
        if (cancelled) return
        const message = err instanceof Error ? err.message : String(err)
        setError('Failed to fetch JSON: ' + message)
        setLoading(false)
      })

    return () => {
      cancelled = true
    }
  }, [dataset])

  const handleMetricChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
    setSelectedMetric(event.target.value)
  }

  const handleAttackChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
    setSelectedAttack(event.target.value)
  }

  // Keep only the selected attack's rows, ordered by strength for the x-axis.
  // `filter` returns a new array, so the in-place `sort` never mutates state.
  const sortedChartData = chartData
    .filter((row) => !selectedAttack || row.attack === selectedAttack)
    .sort((a, b) => Number(a.strength) - Number(b.strength))
  const strengthDomain = getStrengthDomain(sortedChartData)
  // One line per distinct model among the visible rows.
  const models = [...new Set(sortedChartData.map((row) => String(row.model)))]

  return (
    <div className="max-w-4xl rounded shadow p-4 overflow-auto mb-8">
      <h3 className="font-bold mb-2">Data Visualization</h3>
      {loading && <div>Loading...</div>}
      {error && <div className="text-red-500">{error}</div>}
      {!loading && !error && (
        <>
          <div className="flex flex-col md:flex-row gap-4 mb-4">
            <div className="w-full md:w-1/2">
              <MetricSelector
                metrics={metrics}
                selectedMetric={selectedMetric}
                onMetricChange={handleMetricChange}
              />
            </div>
            <div className="w-full md:w-1/2">
              <AttackSelector
                attacks={attacks}
                selectedAttack={selectedAttack}
                onAttackChange={handleAttackChange}
              />
            </div>
          </div>
          {/* Guard on the filtered rows: an attack with no rows has no valid x-domain. */}
          {sortedChartData.length > 0 && (
            <div className="h-64 mb-4">
              <ResponsiveContainer width="100%" height="100%">
                <LineChart
                  data={sortedChartData}
                  margin={{
                    top: 5,
                    right: 30,
                    left: 20,
                    bottom: 5,
                  }}
                >
                  <CartesianGrid strokeDasharray="3 3" />
                  <XAxis
                    dataKey="strength"
                    domain={strengthDomain ?? ['auto', 'auto']}
                    type="number"
                    tickFormatter={formatValue}
                    label={{ value: 'Strength', position: 'insideBottomRight', offset: -5 }}
                  />
                  <YAxis
                    label={{
                      value: selectedMetric ?? '',
                      angle: -90,
                      position: 'insideLeft',
                      style: { textAnchor: 'middle' },
                    }}
                    tickFormatter={formatValue}
                  />
                  <Tooltip
                    contentStyle={{
                      backgroundColor: '#2a303c',
                      borderColor: '#374151',
                      color: 'white',
                    }}
                    formatter={(value) => formatValue(value)}
                  />
                  <Legend />
                  {selectedMetric
                    ? models.map((model, index) => (
                        <Line
                          key={model}
                          type="monotone"
                          dataKey={selectedMetric}
                          data={sortedChartData.filter((row) => String(row.model) === model)}
                          name={model}
                          stroke={MODEL_COLORS[index % MODEL_COLORS.length]}
                          dot={false}
                        />
                      ))
                    : null}
                </LineChart>
              </ResponsiveContainer>
            </div>
          )}
        </>
      )}
    </div>
  )
}

export default DataChart