Spaces:
Running
Running
File size: 4,386 Bytes
cb7223a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
# routers/visualize.py
import logging
import mimetypes
import os

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

from schemas.visualize import (
    VisualizePCARequest,
    VisualizeMeanDiffRequest,
    VisualizeHeatmapRequest,
)
from utils.visualize_pca import (
    run_visualize_pca,
    run_visualize_mean_diff,
    run_visualize_heatmap,
)
# Module logger, explicitly pinned to INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Every visualization endpoint below is mounted under /visualize and
# tagged "visualization" in the generated OpenAPI docs.
router = APIRouter(
    tags=["visualization"],
    prefix="/visualize",
)
@router.post(
    "/pca",
    summary="Generates and returns the PCA visualization of activations",
    response_class=FileResponse,
)
async def visualize_pca_endpoint(req: VisualizePCARequest):
    """
    Receives the parameters, calls the wrapper for optipfair.bias.visualize_pca,
    and returns the resulting PNG/SVG image inline.

    Responds with HTTP 500 if generation raises or no image file is produced.
    """
    # NOTE(review): run_visualize_pca is a plain synchronous call inside an
    # ``async def`` handler, so it executes on the event-loop thread and
    # blocks every other request until it returns. Offloading it (a plain
    # ``def`` endpoint, or fastapi.concurrency.run_in_threadpool) would fix
    # that, but only if the wrapper and its plotting backend are safe to run
    # concurrently -- confirm before changing.
    # NOTE(review): req.output_dir is client-supplied and forwarded unchanged;
    # unless VisualizePCARequest or the wrapper restricts it, clients choose
    # where on the server the image is written -- confirm.

    # 1. Execute the image generation and get the file path
    try:
        filepath = run_visualize_pca(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            highlight_diff=req.highlight_diff,
            output_dir=req.output_dir,
            figure_format=req.figure_format,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging, then return the message to the client
        logger.exception("❌ Error in visualize_pca_endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Verify that the file exists (the wrapper may return a falsy path)
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found after generation")

    # 3. Return the file directly to the client
    filename = os.path.basename(filepath)
    # Derive the MIME type from the file actually written: "image/<format>"
    # is wrong for SVG (registered type is image/svg+xml). Fall back to the
    # previous "image/<format>" value if the extension is unknown.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # "inline" so browsers can render the image instead of downloading it.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )
@router.post("/mean-diff", response_class=FileResponse)
async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
    """
    Receives the parameters, calls the wrapper for optipfair.bias.visualize_mean_differences,
    and returns the resulting PNG/SVG image inline.

    Responds with HTTP 500 if generation raises or no image file is produced.
    """
    # NOTE(review): the wrapper is called synchronously inside ``async def``,
    # so it blocks the event loop while it runs; offloading it to a thread is
    # only safe if the wrapper/plotting backend tolerates concurrency -- confirm.
    # NOTE(review): req.output_dir is client-supplied and forwarded unchanged;
    # confirm the schema or wrapper restricts where files may be written.
    try:
        filepath = run_visualize_mean_diff(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            # This wrapper takes layer_type, not layer_key like the PCA/heatmap ones.
            layer_type=req.layer_type,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in mean-diff endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. Guard a falsy path first: os.path.isfile(None)
    # raises TypeError, which would surface as an unhandled error instead of this 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Return the file directly to the client
    filename = os.path.basename(filepath)
    # "image/<format>" is invalid for SVG (image/svg+xml); derive the type from
    # the written file and fall back to the previous value if unknown.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # "inline" so browsers can render the image instead of downloading it.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )
@router.post("/heatmap", response_class=FileResponse)
async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
    """
    Receives the parameters, calls the wrapper for optipfair.bias.visualize_heatmap,
    and returns the resulting PNG/SVG image inline.

    Responds with HTTP 500 if generation raises or no image file is produced.
    """
    # NOTE(review): the wrapper is called synchronously inside ``async def``,
    # so it blocks the event loop while it runs; offloading it to a thread is
    # only safe if the wrapper/plotting backend tolerates concurrency -- confirm.
    # NOTE(review): req.output_dir is client-supplied and forwarded unchanged;
    # confirm the schema or wrapper restricts where files may be written.
    # NOTE(review): unlike the PCA and mean-diff endpoints, no pair_index is
    # forwarded here -- confirm whether the heatmap request/wrapper support it.
    try:
        filepath = run_visualize_heatmap(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in heatmap endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. Guard a falsy path first: os.path.isfile(None)
    # raises TypeError, which would surface as an unhandled error instead of this 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Return the file directly to the client
    filename = os.path.basename(filepath)
    # "image/<format>" is invalid for SVG (image/svg+xml); derive the type from
    # the written file and fall back to the previous value if unknown.
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        # "inline" so browsers can render the image instead of downloading it.
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )