Pavithiran committed on
Commit e8e15b2 · verified · 1 Parent(s): 36b8fdb

Update sagan_inference.py

Files changed (1)
  1. sagan_inference.py +32 -41
sagan_inference.py CHANGED
@@ -1,21 +1,11 @@
- import torch
- import numpy as np
- import librosa
- from huggingface_hub import hf_hub_download
- from sagan_model import SAGANModel  # your model definition

- ### 1) Download & load your SAGAN weights from your HF repo ###
- SAGAN_WEIGHTS_PATH = hf_hub_download(
-     repo_id="Pavithiran/SAGAN",  # ← replace with your HF namespace
-     filename="sagan_weights.pth"
- )
- model = SAGANModel()
- state_dict = torch.load(SAGAN_WEIGHTS_PATH, map_location="cpu")
- model.load_state_dict(state_dict)
- model.eval()

  ### 2) Age-group Z-score stats (proxy values from literature) ###
- import math
  STATS = {
      "kindergarten": {
          "pitch": {"mu": 30.0, "sigma": 29.0},  # Wise & Sloboda (2008)
@@ -34,32 +24,33 @@ STATS = {
      },
  }

- def sigmoid(z: float) -> float:
-     return 1 / (1 + math.exp(-z))

- def z_score_standardize(raw_metrics: dict, age_group: str) -> dict:
-     if age_group not in STATS:
-         raise ValueError(f"Unknown age_group '{age_group}'")
-     stats = STATS[age_group]
-     out = {}
-     for key, raw in raw_metrics.items():
-         μ, σ = stats[key]["mu"], stats[key]["sigma"]
-         z = (raw - μ) / σ
-         out[key] = round(sigmoid(z), 3)
-     return out

- def run_sagan(wav_path: str) -> dict:
-     """
-     1) Load audio
-     2) Run SAGANModel.evaluate → returns {'pitch_accuracy', 'rhythm_consistency', 'timbre_score'}
-     3) Return raw dict
-     """
-     y, sr = librosa.load(wav_path, sr=16000, mono=True)
      with torch.no_grad():
-         metrics = model.evaluate(y, sr)
-         # Ensure keys:
-         return {
-             "pitch": float(metrics.get("pitch_accuracy", metrics[0])),
-             "rhythm": float(metrics.get("rhythm_consistency", metrics[1])),
-             "timbre": float(metrics.get("timbre_score", metrics[2])),
-         }
+ # sagan_inference.py

+ import torch
+ import torchaudio
+ import math
+ from sagan_model import SAGANModel

  ### 2) Age-group Z-score stats (proxy values from literature) ###
  STATS = {
      "kindergarten": {
          "pitch": {"mu": 30.0, "sigma": 29.0},  # Wise & Sloboda (2008)
  …
      },
  }

+ def z_score_standardize(waveform: torch.Tensor, age_group: str) -> torch.Tensor:
+     stats = STATS.get(age_group, STATS["adult"])
+     mu, sigma = stats["pitch"]["mu"], stats["pitch"]["sigma"]
+     # example for pitch; repeat for rhythm/timbre as needed
+     return (waveform - mu) / (sigma + 1e-9)

+ def run_sagan(audio_path: str, checkpoint_path: str, age_group: str = "adult", device='cpu'):
+     # 1) Load audio
+     waveform, sr = torchaudio.load(audio_path)
+     waveform = z_score_standardize(waveform, age_group).to(device)

+     # 2) Instantiate model & load weights
+     model = SAGANModel(z_dim=128).to(device)
+     ckpt = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(ckpt['model_state_dict'])
+     model.eval()
+
+     # 3) Prepare latent vector from audio (example: mean-pool + linear proj)
+     # Here you’ll replace `encode_to_z` with your custom feature extractor
+     z = encode_to_z(waveform).unsqueeze(-1).unsqueeze(-1)  # -> (1, 128, 1, 1)
+
+     # 4) Generate
      with torch.no_grad():
+         fake_img = model(z)  # -> (1, 3, 64, 64) for a 64×64 SAGAN
+     return fake_img
+
+ # Placeholder: your own mapping from waveform → z
+ def encode_to_z(wf):
+     # e.g., a small CNN or an MLP extracting 128-d features from audio
+     return wf.mean(dim=-1).mean(dim=-1).unsqueeze(0).repeat(1, 128)