Mark Duppenthaler
Add dataset selector
54be5f9
raw
history blame
7.64 kB
import { useEffect, useState } from 'react'
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
Legend,
ResponsiveContainer,
} from 'recharts'
import API from '../API'
/** Props accepted by the DataChart component. */
interface DataChartProps {
  // Dataset name; interpolated into the fetched static file path
  // `data/${dataset}_attacks_variations`.
  dataset: string
}
/**
 * One record from the `all_attacks_df` array of the loaded JSON file.
 * The chart reads `attack`, `model`, `strength` and the currently selected
 * metric's column through the index signature; `strength` may arrive as a
 * string and is converted to a number after loading.
 */
interface Row {
  // NOTE(review): declared as required, but no code in this file reads it —
  // confirm the JSON rows actually carry a `metric` field.
  metric: string
  [key: string]: string | number
}
/**
 * Dropdown for choosing which metric column is plotted on the y-axis.
 * Options appear in the Set's insertion order; a null selection is shown
 * as the empty value.
 */
const MetricSelector = (props: {
  metrics: Set<string>
  selectedMetric: string | null
  onMetricChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
}) => {
  const { metrics, selectedMetric, onMetricChange } = props
  const currentValue = selectedMetric ?? ''
  const options = Array.from(metrics, (name) => (
    <option key={name} value={name}>
      {name}
    </option>
  ))

  return (
    <fieldset className="fieldset mb-4">
      <legend className="fieldset-legend">Metric</legend>
      <select
        id="metric-selector"
        className="select select-bordered w-full"
        value={currentValue}
        onChange={onMetricChange}
      >
        {options}
      </select>
    </fieldset>
  )
}
/**
 * Dropdown for choosing which attack's rows are plotted.
 * Options appear in the Set's insertion order; a null selection is shown
 * as the empty value.
 */
const AttackSelector = (props: {
  attacks: Set<string>
  selectedAttack: string | null
  onAttackChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
}) => {
  const { attacks, selectedAttack, onAttackChange } = props
  const currentValue = selectedAttack ?? ''
  const options = Array.from(attacks, (name) => (
    <option key={name} value={name}>
      {name}
    </option>
  ))

  return (
    <fieldset className="fieldset mb-4">
      <legend className="fieldset-legend">Attack</legend>
      <select
        id="attack-selector"
        className="select select-bordered w-full"
        value={currentValue}
        onChange={onAttackChange}
      >
        {options}
      </select>
    </fieldset>
  )
}
/** Stroke colours assigned to model lines; cycled when there are more models than colours. */
const MODEL_COLORS = ['#8884d8', '#82ca9d', '#ffc658', '#ff8042', '#0088fe', '#00C49F']

/**
 * Copies a raw row from the JSON payload, converting a string `strength`
 * to a number so rows can be sorted and plotted on a numeric x-axis.
 */
const normalizeRow = (row: Row): Row => {
  const newRow: Row = { ...row }
  if (typeof newRow.strength === 'string') {
    newRow.strength = parseFloat(newRow.strength)
  }
  return newRow
}

/**
 * Loads `data/${dataset}_attacks_variations` and plots the selected metric
 * against attack strength, one line per model, for the selected attack.
 *
 * The file is re-fetched whenever `dataset` changes. The metric/attack
 * selections are then reset to the first entries listed in the new file.
 */
const DataChart = ({ dataset }: DataChartProps) => {
  const [chartData, setChartData] = useState<Row[]>([])
  const [loading, setLoading] = useState(true)
  const [error, setError] = useState<string | null>(null)
  const [metrics, setMetrics] = useState<Set<string>>(new Set())
  const [attacks, setAttacks] = useState<Set<string>>(new Set())
  const [selectedMetric, setSelectedMetric] = useState<string | null>(null)
  const [selectedAttack, setSelectedAttack] = useState<string | null>(null)

  useEffect(() => {
    // Flipped by the cleanup when `dataset` changes (or on unmount) so that a
    // slow response for a previous dataset cannot overwrite newer state.
    let cancelled = false
    setLoading(true)
    setError(null)

    API.fetchStaticFile(`data/${dataset}_attacks_variations`)
      .then((response) => {
        if (cancelled) return
        const data = JSON.parse(response)
        const rows: Row[] = data['all_attacks_df'].map(normalizeRow)
        const metricNames: string[] = data['metrics']
        const attackNames: string[] = data['attacks_with_variations']
        // Empty lists yield `undefined`; keep the state type `string | null`.
        setSelectedMetric(metricNames[0] ?? null)
        setMetrics(new Set(metricNames))
        setSelectedAttack(attackNames[0] ?? null)
        setAttacks(new Set(attackNames))
        setChartData(rows)
        setLoading(false)
      })
      .catch((err: unknown) => {
        if (cancelled) return
        const message = err instanceof Error ? err.message : String(err)
        setError('Failed to fetch JSON: ' + message)
        setLoading(false)
      })

    return () => {
      cancelled = true
    }
  }, [dataset])

  const handleMetricChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
    setSelectedMetric(event.target.value)
  }

  const handleAttackChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
    setSelectedAttack(event.target.value)
  }

  // Rows for the selected attack, ordered by strength so lines draw left to right.
  // `filter` returns a fresh array, so sorting it in place does not mutate state.
  const sortedChartData = chartData
    .filter((row) => !selectedAttack || row.attack === selectedAttack)
    .sort((a, b) => Number(a.strength) - Number(b.strength))

  // Explicit x-axis bounds. Left to recharts when no rows match, because
  // Math.min/Math.max of an empty list give Infinity/-Infinity.
  const strengths = sortedChartData.map((row) => Number(row.strength))
  const xDomain: [number, number] | undefined =
    strengths.length > 0 ? [Math.min(...strengths), Math.max(...strengths)] : undefined

  // Distinct models among the visible rows; each gets its own line.
  const models = [...new Set(sortedChartData.map((row) => row.model))]

  return (
    <div className="max-w-4xl rounded shadow p-4 overflow-auto mb-8">
      <h3 className="font-bold mb-2">Data Visualization</h3>
      {loading && <div>Loading...</div>}
      {error && <div className="text-red-500">{error}</div>}
      {!loading && !error && (
        <>
          <div className="flex flex-col md:flex-row gap-4 mb-4">
            <div className="w-full md:w-1/2">
              <MetricSelector
                metrics={metrics}
                selectedMetric={selectedMetric}
                onMetricChange={handleMetricChange}
              />
            </div>
            <div className="w-full md:w-1/2">
              <AttackSelector
                attacks={attacks}
                selectedAttack={selectedAttack}
                onAttackChange={handleAttackChange}
              />
            </div>
          </div>
          {chartData.length > 0 && (
            <div className="h-64 mb-4">
              <ResponsiveContainer width="100%" height="100%">
                <LineChart
                  data={sortedChartData}
                  margin={{
                    top: 5,
                    right: 30,
                    left: 20,
                    bottom: 5,
                  }}
                >
                  <CartesianGrid strokeDasharray="3 3" />
                  <XAxis
                    dataKey="strength"
                    domain={xDomain}
                    type="number"
                    tickFormatter={(value) => value.toFixed(3)}
                    label={{ value: 'Strength', position: 'insideBottomRight', offset: -5 }}
                  />
                  <YAxis
                    label={{
                      value: selectedMetric || '',
                      angle: -90,
                      position: 'insideLeft',
                      style: { textAnchor: 'middle' },
                    }}
                    tickFormatter={(value) => value.toFixed(3)}
                  />
                  <Tooltip
                    contentStyle={{
                      backgroundColor: '#2a303c',
                      borderColor: '#374151',
                      color: 'white',
                    }}
                    formatter={(value: number) => value.toFixed(3)}
                  />
                  <Legend />
                  {/* No lines are drawn until a metric has been selected. */}
                  {selectedMetric &&
                    models.map((model, index) => (
                      <Line
                        key={String(model)}
                        type="monotone"
                        dataKey={selectedMetric}
                        data={sortedChartData.filter((row) => row.model === model)}
                        name={String(model)}
                        stroke={MODEL_COLORS[index % MODEL_COLORS.length]}
                        dot={false}
                      />
                    ))}
                </LineChart>
              </ResponsiveContainer>
            </div>
          )}
        </>
      )}
    </div>
  )
}

export default DataChart