File size: 4,386 Bytes
cb7223a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# routers/visualize.py
import logging
import mimetypes
import os

from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse

from schemas.visualize import (
    VisualizePCARequest,
    VisualizeMeanDiffRequest,
    VisualizeHeatmapRequest,
)
from utils.visualize_pca import (
    run_visualize_pca,
    run_visualize_mean_diff,
    run_visualize_heatmap,
)

# Per-module logger, pinned to INFO regardless of the level it would
# otherwise inherit from its ancestors.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Every endpoint below is mounted under "/visualize" and grouped under the
# "visualization" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/visualize", tags=["visualization"])

@router.post(
    "/pca",
    summary="Generates and returns the PCA visualization of activations",
    response_class=FileResponse,
)
async def visualize_pca_endpoint(req: VisualizePCARequest):
    """
    Generate the PCA visualization of activations for a prompt pair.

    Receives the parameters, calls the wrapper for optipfair.bias.visualize_pca,
    and returns the resulting PNG/SVG image inline.

    Args:
        req: Validated request body. Its fields are forwarded to
            ``run_visualize_pca``; ``prompt_pair`` is converted to a tuple.

    Returns:
        A ``FileResponse`` for the generated image, sent with an ``inline``
        Content-Disposition so browsers display it instead of downloading it.

    Raises:
        HTTPException: 500 if generation raises (the exception message is
            returned as ``detail``) or if no image file exists afterwards.
    """
    # NOTE(review): run_visualize_pca is a synchronous call inside an async
    # endpoint, so it blocks the event loop until it returns. Confirm the
    # helper is safe to run concurrently before offloading it to a thread.
    # NOTE(review): output_dir is taken straight from the request body and is
    # presumably where the image gets written before being served back.
    # Confirm the schema or wrapper restricts it to a safe location.
    # 1. Execute the image generation and get the file path
    try:
        filepath = run_visualize_pca(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            highlight_diff=req.highlight_diff,
            output_dir=req.output_dir,
            figure_format=req.figure_format,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging, then return the message to the client
        logger.exception("❌ Error in visualize_pca_endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # 2. Verify that the file exists (the wrapper may return None or a stale path)
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found after generation")

    # 3. Return the file directly to the client.
    # "image/<format>" is not a valid MIME type for every format (SVG must be
    # served as "image/svg+xml"), so derive it from the file extension and
    # fall back to the previous scheme only when the extension is unknown.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )

@router.post("/mean-diff", response_class=FileResponse)
async def visualize_mean_diff_endpoint(req: VisualizeMeanDiffRequest):
    """
    Generate the mean-activation-difference visualization for a prompt pair.

    Receives the parameters, calls the wrapper for
    optipfair.bias.visualize_mean_differences, and returns the resulting
    PNG/SVG image inline.

    Args:
        req: Validated request body. Its fields are forwarded to
            ``run_visualize_mean_diff``; ``prompt_pair`` is converted to a tuple.

    Returns:
        A ``FileResponse`` for the generated image, sent with an ``inline``
        Content-Disposition so browsers display it instead of downloading it.

    Raises:
        HTTPException: 500 if generation raises (the exception message is
            returned as ``detail``) or if no image file exists afterwards.
    """
    # NOTE(review): synchronous call inside an async endpoint; it blocks the
    # event loop until generation finishes.
    # NOTE(review): output_dir comes straight from the request body; confirm
    # it is restricted to a safe location.
    try:
        filepath = run_visualize_mean_diff(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_type=req.layer_type,  # this wrapper takes layer_type, not layer_key
            figure_format=req.figure_format,
            output_dir=req.output_dir,
            pair_index=req.pair_index,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in mean-diff endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. Guard against a falsy path first:
    # os.path.isfile(None) raises TypeError rather than returning False,
    # which would escape as an unstructured 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Return the file directly to the client. Derive the MIME type from the
    # extension ("image/svg" is invalid; SVG is "image/svg+xml"), falling
    # back to the previous scheme for unknown extensions.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )

@router.post("/heatmap", response_class=FileResponse)
async def visualize_heatmap_endpoint(req: VisualizeHeatmapRequest):
    """
    Generate the activation-difference heatmap for a prompt pair.

    Receives the parameters, calls the wrapper for optipfair.bias.visualize_heatmap,
    and returns the resulting PNG/SVG image inline.

    Args:
        req: Validated request body. Its fields are forwarded to
            ``run_visualize_heatmap``; ``prompt_pair`` is converted to a tuple.

    Returns:
        A ``FileResponse`` for the generated image, sent with an ``inline``
        Content-Disposition so browsers display it instead of downloading it.

    Raises:
        HTTPException: 500 if generation raises (the exception message is
            returned as ``detail``) or if no image file exists afterwards.
    """
    # NOTE(review): synchronous call inside an async endpoint; it blocks the
    # event loop until generation finishes.
    # NOTE(review): output_dir comes straight from the request body; confirm
    # it is restricted to a safe location.
    try:
        filepath = run_visualize_heatmap(
            model_name=req.model_name,
            prompt_pair=tuple(req.prompt_pair),
            layer_key=req.layer_key,
            figure_format=req.figure_format,
            output_dir=req.output_dir,
        )
    except Exception as e:
        # Log the full trace for debugging
        logger.exception("Error in heatmap endpoint")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Verify that the file exists. Guard against a falsy path first:
    # os.path.isfile(None) raises TypeError rather than returning False,
    # which would escape as an unstructured 500.
    if not filepath or not os.path.isfile(filepath):
        raise HTTPException(status_code=500, detail="Image file not found")

    # Return the file directly to the client. Derive the MIME type from the
    # extension ("image/svg" is invalid; SVG is "image/svg+xml"), falling
    # back to the previous scheme for unknown extensions.
    filename = os.path.basename(filepath)
    media_type = mimetypes.guess_type(filepath)[0] or f"image/{req.figure_format}"
    return FileResponse(
        path=filepath,
        media_type=media_type,
        filename=filename,
        headers={"Content-Disposition": f'inline; filename="{filename}"'},
    )