Spaces:
Running
Running
# routers/visualize.py | |
import logging
import mimetypes
import os

from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse

from schemas.visualize import (
    VisualizePCARequest,
    VisualizeMeanDiffRequest,
    VisualizeHeatmapRequest,
)
from utils.visualize_pca import (
    run_visualize_pca,
    run_visualize_mean_diff,
    run_visualize_heatmap,
)
# Module logger, explicitly set to INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Every route attached to this router is mounted under "/visualize" and
# tagged "visualization".
_ROUTER_PREFIX = "/visualize"
_ROUTER_TAGS = ["visualization"]
router = APIRouter(prefix=_ROUTER_PREFIX, tags=_ROUTER_TAGS)
@router.post("/pca")
async def visualize_pca_endpoint(req: VisualizePCARequest):
    """Generate a PCA visualization and return it as an image file.

    Calls ``run_visualize_pca`` (the wrapper for
    ``optipfair.bias.visualize_pca``) with the request parameters and returns
    the generated file, marked for inline display.

    Args:
        req: Request body carrying the model name, prompt pair, layer key,
            highlight flag, output directory, figure format and pair index.

    Returns:
        A ``FileResponse`` streaming the generated image.

    Raises:
        HTTPException: 500 if generation fails or no file is produced.
    """
    # 1. Generate the image. The wrapper is a plain synchronous call; running
    #    it in a worker thread keeps it from blocking the event loop that
    #    serves all other requests while it works.
    try:
        filepath = await run_in_threadpool(
            run_visualize_pca,
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            highlight_diff=req.highlight_diff,
            output_dir=req.output_dir,
            figure_format=req.figure_format,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full traceback server-side, return the message to the client.
        logger.exception("β Error in visualize_pca_endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Verify that the wrapper actually produced a file.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found after generation")

    # 3. Return the file. "image/<format>" is wrong for SVG (the registered
    #    type is image/svg+xml), so derive the type from the file extension
    #    and only fall back to the old guess for unknown extensions.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )
@router.post("/mean-diff")
async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
    """Generate a mean-difference visualization and return it as an image file.

    Calls ``run_visualize_mean_diff`` (the wrapper for
    ``optipfair.bias.visualize_mean_differences``) with the request parameters
    and returns the generated file, marked for inline display.

    Args:
        req: Request body carrying the model name, prompt pair, layer type,
            figure format, output directory and pair index.

    Returns:
        A ``FileResponse`` streaming the generated image.

    Raises:
        HTTPException: 500 if generation fails or no file is produced.
    """
    # The wrapper is a plain synchronous call; run it in a worker thread so
    # it does not block the event loop.
    try:
        filepath = await run_in_threadpool(
            run_visualize_mean_diff,
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            # This wrapper selects layers by type (layer_type), not by a
            # single layer key as the PCA and heatmap wrappers do.
            layer_type=req.layer_type,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full traceback server-side, return the message to the client.
        logger.exception("Error in mean-diff endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Guard against a falsy path first: os.path.isfile(None) raises TypeError,
    # which would escape as an unhandled error instead of this clean 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Derive the MIME type from the extension ("image/svg" is not valid;
    # SVG is image/svg+xml), falling back to image/<format> if unknown.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )
@router.post("/heatmap")
async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
    """Generate an activation heatmap and return it as an image file.

    Calls ``run_visualize_heatmap`` (the wrapper for
    ``optipfair.bias.visualize_heatmap``) with the request parameters and
    returns the generated file, marked for inline display.

    Args:
        req: Request body carrying the model name, prompt pair, layer key,
            figure format and output directory.

    Returns:
        A ``FileResponse`` streaming the generated image.

    Raises:
        HTTPException: 500 if generation fails or no file is produced.
    """
    # The wrapper is a plain synchronous call; run it in a worker thread so
    # it does not block the event loop.
    # NOTE(review): unlike the PCA and mean-diff endpoints, no pair_index is
    # forwarded here; confirm whether the heatmap request/wrapper support it.
    try:
        filepath = await run_in_threadpool(
            run_visualize_heatmap,
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
        )
    except Exception as e:
        # Log the full traceback server-side, return the message to the client.
        logger.exception("Error in heatmap endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Guard against a falsy path first: os.path.isfile(None) raises TypeError,
    # which would escape as an unhandled error instead of this clean 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Derive the MIME type from the extension ("image/svg" is not valid;
    # SVG is image/svg+xml), falling back to image/<format> if unknown.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )