File size: 4,418 Bytes
c2d0da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1195123
c2d0da9
1195123
 
448dbfb
c2d0da9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import csv
import io
import json
import os
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List

import modal
import pandas as pd
from fastapi import Response
from modal import web_endpoint
from pydantic import BaseModel, Field

from rating import compute_mle_elo

# -----------------------
# Data Model Definition
# -----------------------
class ExperienceEnum(int, Enum):
    """Experience levels encoded as the integers 1 (lowest) to 3 (highest).

    Used as the type of ``Battle.experience``; presumably the self-reported
    expertise of the judge -- TODO confirm against the submitting client.
    """
    novice = 1
    intermediate = 2
    expert = 3

class Winner(str, Enum):
    """Possible outcomes of a battle: either side wins, or the battle is a tie.

    Members are ``str`` subclasses, so their values are the plain strings
    "model_a", "model_b" and "tie".
    """
    model_a = "model_a"
    model_b = "model_b"
    tie = "tie"


class Model(str, Enum):
    """Closed set of model identifiers accepted in battles and ratings.

    The values look like ``owner/name`` repository ids -- presumably model
    hub ids; confirm against wherever the models are hosted.
    """
    porestar_deepfault_unet_baseline_no_augment = "porestar/deepfault-unet-baseline-no-augment"
    porestar_deepfault_unet_baseline_weak_augment = "porestar/deepfault-unet-baseline-weak-augment"
    porestar_deepfault_unet_baseline_full_augment = "porestar/deepfault-unet-baseline-full-augment"

class Battle(BaseModel):
    """One pairwise comparison between two models, as submitted to the API.

    ``experience`` defaults to ``ExperienceEnum.novice`` and ``tstamp``
    defaults to the string form of the local time at which the instance is
    created.
    """
    # The two models being compared.
    model_a: Model
    model_b: Model
    # Which side won, or "tie".
    winner: Winner
    # Free-form identifier of whoever made the judgement.
    judge: str
    # Index of the image the comparison refers to.
    image_idx: int
    experience: ExperienceEnum = ExperienceEnum.novice
    # Use a factory so the timestamp is taken per instance. A plain
    # ``str(datetime.now())`` default is evaluated once at import time and
    # would stamp every battle with the process start time.
    tstamp: str = Field(default_factory=lambda: str(datetime.now()))

class EloRating(BaseModel):
    """A single row of the ratings table: one model and its Elo rating."""
    # Model the rating belongs to; must be a member of ``Model``.
    model: Model
    # Rating value as produced by ``compute_mle_elo``.
    elo_rating: float

# -----------------------
# Modal Configuration
# -----------------------

# Create a volume to persist data
# (looked up by name; `create_if_missing=True` creates it on first use).
data_volume = modal.Volume.from_name("seisbase-data", create_if_missing=True)

# Both files live under "/data", the path the volume is attached to below.
JSON_FILE_PATH = Path("/data/battles.json")    # stored battles (JSON array)
RESULTS_FILE_PATH = Path("/data/ratings.csv")  # ratings written by compute_ratings

# Container image: Debian slim with Python 3.10 plus the listed pip packages.
# NOTE(review): fastapi and pydantic are imported by this file but are not
# listed here -- confirm the base image or Modal runtime provides them.
app_image = modal.Image.debian_slim(python_version="3.10").pip_install("pandas", "scikit-learn", "tqdm", "sympy")

# The Modal app; its functions use `app_image` and get `data_volume` at "/data".
app = modal.App(
    image=app_image,
    name="seisbase-eval",
    volumes={"/data": data_volume},
)

def ensure_json_file():
    """Create the battles file, holding an empty JSON array, if it is absent.

    Missing parent directories are created as well. An existing file is left
    untouched.
    """
    if os.path.exists(JSON_FILE_PATH):
        return

    JSON_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
    empty_array = json.dumps([])
    with open(JSON_FILE_PATH, "w") as fresh_file:
        fresh_file.write(empty_array)

def append_to_json_file(data):
    """Append one record to the JSON array stored at ``JSON_FILE_PATH``.

    The file is created first if it is missing. If its current content is not
    valid JSON it is treated as corrupted and the array is restarted with
    ``data`` as its only element.

    Args:
        data: A JSON-serializable object to append to the stored array.

    Raises:
        RuntimeError: If reading, serializing or writing fails; the original
            exception is attached as ``__cause__``.
    """
    ensure_json_file()
    try:
        with open(JSON_FILE_PATH, "r+") as f:
            try:
                battles = json.load(f)
            except json.JSONDecodeError:
                # Reset the file if corrupted
                battles = []
            battles.append(data)
            # Serialize before touching the file: if ``data`` is not
            # JSON-serializable the existing content must stay intact instead
            # of being overwritten by a half-written document.
            payload = json.dumps(battles, indent=4)
            f.seek(0)
            f.write(payload)
            # Drop any leftover bytes in case the new content is shorter.
            f.truncate()
    except Exception as e:
        raise RuntimeError(f"Failed to append data to JSON file: {e}") from e

def read_json_file():
    """Load and return the content of the battles file.

    The file is created first if it does not exist. Content that is not valid
    JSON yields an empty list; any other failure is re-raised as
    ``RuntimeError``.
    """
    ensure_json_file()
    try:
        with JSON_FILE_PATH.open("r") as source:
            raw_text = source.read()
        try:
            records = json.loads(raw_text)
        except json.JSONDecodeError:
            # A corrupted file is reported as "no battles".
            records = []
        return records
    except Exception as err:
        raise RuntimeError(f"Failed to read JSON file: {err}")

@app.function()
@web_endpoint(method="POST", docs=True)
def add_battle(battle: Battle):
    """Store the submitted battle and echo it back with a success status."""
    record = battle.dict()
    append_to_json_file(record)

    reply = {"status": "success"}
    reply["battle"] = record
    return reply


@app.function()
@web_endpoint(method="GET", docs=True)
def export_csv():
    """Return every stored battle as a CSV document.

    Returns:
        A ``text/csv`` response: a header row followed by one row per battle.
    """
    battles = read_json_file()

    # Column names must match the keys of the stored battle dicts exactly:
    # DictWriter raises ValueError for any row key that is not listed here,
    # so a misspelt column (e.g. "imaged_idx") breaks every non-empty export.
    fieldnames = [
        "model_a",
        "model_b",
        "winner",
        "judge",
        "image_idx",
        "experience",
        "tstamp",
    ]

    # Create CSV in memory
    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(battles)

    csv_data = output.getvalue()
    return Response(content=csv_data, media_type="text/csv")

@app.function()
@web_endpoint(method="GET", docs=True)
def compute_ratings() -> List[EloRating]:
    """Compute Elo ratings from all stored battles.

    The ratings returned by ``compute_mle_elo`` are written to
    ``RESULTS_FILE_PATH`` and then read back to build the response.

    Returns:
        One ``EloRating`` per model, highest rating first. An empty list when
        no battles have been stored yet.
    """
    # Go through the shared reader so a missing or corrupted store behaves the
    # same way here as in the other endpoints, instead of raising.
    records = read_json_file()
    if not records:
        # Nothing to rate; sorting an empty frame by "tstamp" would raise.
        return []

    # Build the frame straight from the stored dicts. ``pandas.read_json``
    # documents ``dtype`` only as a bool or a column -> dtype dict, so a
    # positional list of types is not a supported way to pin column types.
    battles = (
        pd.DataFrame(records)
        .sort_values(by=["tstamp"], ascending=True)
        .reset_index(drop=True)
    )
    elo_mle_ratings = compute_mle_elo(battles)
    elo_mle_ratings.to_csv(RESULTS_FILE_PATH)

    # NOTE(review): relabelling to two columns assumes compute_mle_elo returns
    # a Series indexed by model name -- confirm against the rating module.
    df = pd.read_csv(RESULTS_FILE_PATH)
    df.columns = ["Model", "Elo rating"]
    df = df.sort_values("Elo rating", ascending=False).reset_index(drop=True)
    return [
        EloRating(model=model, elo_rating=rating)
        for model, rating in zip(df["Model"], df["Elo rating"])
    ]

@app.local_entrypoint()
def main():
    """Print a short notice when the app is started via the local entrypoint."""
    notice = "Local entrypoint running. Check endpoints for functionality."
    print(notice)