Pavithiran commited on
Commit
c98cb1d
·
verified ·
1 Parent(s): 9ee0fa2

Create sagan_inference.py

Browse files
Files changed (1) hide show
  1. sagan_inference.py +65 -0
sagan_inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import librosa
4
+ from huggingface_hub import hf_hub_download
5
+ from sagan_model import SAGANModel # your model definition
6
+
7
+ ### 1) Download & load your SAGAN weights from your HF repo ###
8
+ SAGAN_WEIGHTS_PATH = hf_hub_download(
9
+ repo_id="YOUR_USERNAME/sagan-space", # ← replace with your HF namespace
10
+ filename="sagan_weights.pth"
11
+ )
12
+ model = SAGANModel()
13
+ state_dict = torch.load(SAGAN_WEIGHTS_PATH, map_location="cpu")
14
+ model.load_state_dict(state_dict)
15
+ model.eval()
16
+
17
+ ### 2) Age-group Z-score stats (proxy values from literature) ###
18
+ import math
19
+ STATS = {
20
+ "kindergarten": {
21
+ "pitch": {"mu": 30.0, "sigma": 29.0}, # Wise & Sloboda (2008)
22
+ "rhythm": {"mu": 60.0, "sigma": 15.0}, # Demorest & Pfordresher (2015)
23
+ "timbre": {"mu": 0.65, "sigma": 0.10},
24
+ },
25
+ "grade_6": {
26
+ "pitch": {"mu": 43.0, "sigma": 26.0},
27
+ "rhythm": {"mu": 75.0, "sigma": 10.0},
28
+ "timbre": {"mu": 0.75, "sigma": 0.08},
29
+ },
30
+ "adult": {
31
+ "pitch": {"mu": 32.0, "sigma": 19.0},
32
+ "rhythm": {"mu": 80.0, "sigma": 8.0},
33
+ "timbre": {"mu": 0.85, "sigma": 0.05},
34
+ },
35
+ }
36
+
37
+ def sigmoid(z: float) -> float:
38
+ return 1 / (1 + math.exp(-z))
39
+
40
+ def z_score_standardize(raw_metrics: dict, age_group: str) -> dict:
41
+ if age_group not in STATS:
42
+ raise ValueError(f"Unknown age_group '{age_group}'")
43
+ stats = STATS[age_group]
44
+ out = {}
45
+ for key, raw in raw_metrics.items():
46
+ μ, σ = stats[key]["mu"], stats[key]["sigma"]
47
+ z = (raw - μ) / σ
48
+ out[key] = round(sigmoid(z), 3)
49
+ return out
50
+
51
+ def run_sagan(wav_path: str) -> dict:
52
+ """
53
+ 1) Load audio
54
+ 2) Run SAGANModel.evaluate → returns {'pitch_accuracy', 'rhythm_consistency', 'timbre_score'}
55
+ 3) Return raw dict
56
+ """
57
+ y, sr = librosa.load(wav_path, sr=16000, mono=True)
58
+ with torch.no_grad():
59
+ metrics = model.evaluate(y, sr)
60
+ # Ensure keys:
61
+ return {
62
+ "pitch": float(metrics.get("pitch_accuracy", metrics[0])),
63
+ "rhythm": float(metrics.get("rhythm_consistency", metrics[1])),
64
+ "timbre": float(metrics.get("timbre_score", metrics[2])),
65
+ }