|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
To run this script from the root of the repo, first make sure Flask is installed, then run:
|
|
|
|
FLASK_DEBUG=1 FLASK_APP=scripts.mos flask run -p 4567
|
|
# or if you have gunicorn
|
|
gunicorn -w 4 -b 127.0.0.1:8895 -t 120 'scripts.mos:app' --access-logfile -
|
|
|
|
"""
|
|
from collections import defaultdict
|
|
from functools import wraps
|
|
from hashlib import sha1
|
|
import json
|
|
import math
|
|
from pathlib import Path
|
|
import random
|
|
import typing as tp
|
|
|
|
from flask import Flask, redirect, render_template, request, session, url_for
|
|
|
|
from audiocraft import train
|
|
from audiocraft.utils.samples.manager import get_samples_for_xps
|
|
|
|
|
|
# Number of samples displayed on each survey page.
SAMPLES_PER_PAGE = 8
# Ratings offered to the user go from 1 to `MAX_RATING` included.
MAX_RATING = 5

# Everything the app persists lives in `mos_storage` inside the Dora directory,
# with one sub folder per survey under `surveys`.
storage = Path(train.main.dora.dir / 'mos_storage')
storage.mkdir(exist_ok=True)
surveys = storage / 'surveys'
surveys.mkdir(exist_ok=True)

# Static files and templates are looked up in `scripts/`, two levels above `train.__file__`.
magma_root = Path(train.__file__).parent.parent
app = Flask(
    'mos',
    static_folder=str(magma_root / 'scripts/static'),
    template_folder=str(magma_root / 'scripts/templates'),
)
# NOTE(review): the key signing the session cookie is hard-coded here, which is
# presumably fine for an internal tool only -- confirm before exposing the app.
app.secret_key = b'audiocraft makes the best songs'
|
|
|
|
|
|
def normalize_path(path: Path):
    """Shorten `path` for display by expressing it relative to the `xps` folder
    of the Dora root dir.

    Raises `ValueError` (from `Path.relative_to`) if `path` is not located under that folder.
    """
    absolute = path.resolve()
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return absolute.relative_to(xps_root)
|
|
|
|
|
|
def get_full_path(normalized_path: Path):
    """Inverse of `normalize_path`: prepend the resolved `xps` folder of the Dora root dir.
    """
    xps_root = train.main.dora.dir.resolve() / 'xps'
    return xps_root / normalized_path
|
|
|
|
|
|
def get_signature(xps: tp.List[str]):
    """Return a short signature identifying a list of XP signatures.

    The list is serialized to JSON and hashed with SHA-1, of which only the first
    10 hexadecimal characters are kept. The result depends on the order of `xps`.
    """
    payload = json.dumps(xps).encode()
    digest = sha1(payload).hexdigest()
    return digest[:10]
|
|
|
|
|
|
def ensure_logged(func):
    """Decorator for views requiring a user: when the session holds no user,
    redirect to the login page instead of calling `func`.
    """
    @wraps(func)
    def _guarded(*args, **kwargs):
        if session.get('user') is not None:
            return func(*args, **kwargs)
        # Pass the requested URL along so that `login` can send the user back to it.
        return redirect(url_for('login', redirect_to=request.url))
    return _guarded
|
|
|
|
|
|
@app.route('/login', methods=['GET', 'POST'])
def login():
    """Login user if not already, then redirect.

    When the session holds no user, the login form is rendered on GET, and on POST
    the `user` form field is stored in the session (an empty value renders the form
    again with an error). Once a user is known, redirect to the `redirect_to` query
    argument, or to the index page when it is missing or does not point to this site.
    """
    # Local import: keeps this view self contained without touching the file imports.
    from urllib.parse import urlparse

    user = session.get('user')
    if user is None:
        error = None
        if request.method == 'POST':
            user = request.form['user']
            if not user:
                error = 'User cannot be empty'
        if user is None or error:
            return render_template('login.html', error=error)
        session['user'] = user

    redirect_to = request.args.get('redirect_to')
    if redirect_to is not None:
        # `redirect_to` is a query argument anyone can forge: only follow it when it
        # stays on this host, otherwise the login page is usable as an open redirect.
        try:
            target = urlparse(redirect_to)
            same_site = (
                target.scheme in ('', 'http', 'https')
                and target.netloc in ('', request.host)
                # Browsers may treat a backslash as a slash, e.g. `/\host`.
                and '\\' not in redirect_to
            )
        except ValueError:
            same_site = False
        if not same_site:
            redirect_to = None
    if redirect_to is None:
        redirect_to = url_for('index')
    return redirect(redirect_to)
|
|
|
|
|
|
@app.route('/', methods=['GET', 'POST'])
@ensure_logged
def index():
    """Offer to create a new study.

    On GET, render the study creation form. On POST, read the whitespace separated
    XP signatures or grid names from the `xps` form field, expand each grid into the
    XPs it links to, write the study manifest in a folder named after the study
    signature and redirect to the matching survey. When an entry is neither an XP
    nor a grid, or when no XP is found at all, the form is rendered again with the
    list of errors.
    """
    errors = []
    if request.method == 'POST':
        dora_dir = train.main.dora.dir
        xps = set()
        # `str.split()` without argument already drops surrounding whitespace and empty parts.
        for xp_or_grid in request.form['xps'].split():
            if (dora_dir / 'xps' / xp_or_grid).exists():
                xps.add(xp_or_grid)
                continue
            grid_path = dora_dir / 'grids' / xp_or_grid
            if grid_path.exists():
                # The XPs of a grid are the symlinks found in the grid folder.
                xps.update(child.name for child in grid_path.iterdir() if child.is_symlink())
                # A valid grid must not fall through to the error below.
                continue
            errors.append(f'{xp_or_grid} is neither an XP nor a grid!')
        if not xps and not errors:
            # Empty form field, or grids containing no XP: report it rather than crash.
            errors.append('No XP found, please provide at least one XP or non empty grid.')
        blind = 'true' if request.form.get('blind') == 'on' else 'false'
        if not errors:
            # Sort so that a given set of XPs always gets the same signature,
            # whatever the iteration order of the set is in this process.
            xp_list = sorted(xps)
            signature = get_signature(xp_list)
            manifest = {
                'xps': xp_list,
            }
            survey_path = surveys / signature
            survey_path.mkdir(exist_ok=True)
            with open(survey_path / 'manifest.json', 'w') as f:
                json.dump(manifest, f, indent=2)
            return redirect(url_for('survey', blind=blind, signature=signature))
    return render_template('index.html', errors=errors)
|
|
|
|
|
|
@app.route('/survey/<signature>', methods=['GET', 'POST'])
@ensure_logged
def survey(signature):
    """Display one page of the survey `signature`, and store its ratings once submitted.

    The page content is driven by the following query arguments:
        seed: seed of the random generator selecting (and, when blind, shuffling)
            the samples of the page, defaults to 4321.
        blind: when true, the models are shuffled for each sample.
        exclude_prompted, exclude_unprompted, max_epoch: forwarded to `get_samples_for_xps`.
        success: forwarded as is to the template.

    On POST, every model of every displayed sample must come with an integer rating
    between 1 and `MAX_RATING`. If so, the ratings are saved to one JSON file per
    user and seed, and the user is redirected to the page for `seed + 1`. Otherwise
    the page is rendered again with an error attached to each faulty model.
    Responds with a 404 if the survey does not exist.
    """
    # Local import: `abort` is not part of the names imported at the top of the file.
    from flask import abort

    truthy = ('true', 'on', 'True')
    success = request.args.get('success', False)
    seed = int(request.args.get('seed', 4321))
    blind = request.args.get('blind', 'false') in truthy
    exclude_prompted = request.args.get('exclude_prompted', 'false') in truthy
    exclude_unprompted = request.args.get('exclude_unprompted', 'false') in truthy
    max_epoch = int(request.args.get('max_epoch', '-1'))
    survey_path = surveys / signature
    if not survey_path.exists():
        abort(404)

    user = session['user']
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)
    # The user name is free text typed in the login form: neutralize path separators
    # so that the result file cannot be written outside of `result_folder`.
    safe_user = user.replace('/', '_').replace('\\', '_')
    result_file = result_folder / f'{safe_user}_{seed}.json'

    with open(survey_path / 'manifest.json') as f:
        manifest = json.load(f)

    xps = [train.main.get_xp_from_sig(xp) for xp in manifest['xps']]
    names, ref_name = train.main.get_names(xps)

    samples_kwargs = {
        'exclude_prompted': exclude_prompted,
        'exclude_unprompted': exclude_unprompted,
        'max_epoch': max_epoch,
    }
    matched_samples = get_samples_for_xps(xps, epoch=-1, **samples_kwargs)
    # For each sample id, one entry per XP, in the same order as `xps`.
    models_by_id = {
        sample_id: [{
            'xp': xps[idx],
            'xp_name': names[idx],
            'model_id': f'{xps[idx].sig}-{sample.id}',
            'sample': sample,
            'is_prompted': sample.prompt is not None,
            'errors': [],
        } for idx, sample in enumerate(samples)]
        for sample_id, samples in matched_samples.items()
    }
    # The epoch displayed for each XP is read from the first group of matched samples,
    # and is None when no sample matched at all.
    first_samples = next(iter(matched_samples.values()), None)
    experiments = [
        {
            'xp': xp,
            'name': names[idx],
            'epoch': first_samples[idx].epoch if first_samples is not None else None,
        }
        for idx, xp in enumerate(xps)
    ]

    # Sort before shuffling so that a given seed always selects the same page,
    # which is required to match a POST with the page rendered by the previous GET.
    keys = sorted(matched_samples.keys())
    rng = random.Random(seed)
    rng.shuffle(keys)
    model_ids = keys[:SAMPLES_PER_PAGE]

    if blind:
        for key in model_ids:
            rng.shuffle(models_by_id[key])

    ok = True
    if request.method == 'POST':
        all_samples_results = []
        for sample_id in model_ids:
            models = models_by_id[sample_id]
            result = {
                'id': sample_id,
                'is_prompted': models[0]['is_prompted'],
                'models': {},
            }
            all_samples_results.append(result)
            for model in models:
                # The field may be missing from the form entirely (e.g. an unchecked
                # radio button): handle it like an empty value rather than as a bad request.
                raw_rating = request.form.get(model['model_id'], '')
                if not raw_rating:
                    ok = False
                    model['errors'].append('Please rate this model.')
                    continue
                try:
                    rating = int(raw_rating)
                except ValueError:
                    rating = None
                if rating is None or not 1 <= rating <= MAX_RATING:
                    ok = False
                    model['errors'].append(
                        f'Rating must be an integer between 1 and {MAX_RATING}.')
                    continue
                result['models'][model['xp'].sig] = rating
                model['rating'] = rating
        if ok:
            result = {
                'results': all_samples_results,
                'seed': seed,
                'user': user,
                'blind': blind,
                'exclude_prompted': exclude_prompted,
                'exclude_unprompted': exclude_unprompted,
            }
            # Echo the submission on the server output.
            print(result)
            with open(result_file, 'w') as f:
                json.dump(result, f)
            # Move on to the next page, keeping the same options.
            seed = seed + 1
            return redirect(url_for(
                'survey', signature=signature, blind=blind, seed=seed,
                exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted,
                max_epoch=max_epoch, success=True))

    ratings = list(range(1, MAX_RATING + 1))
    return render_template(
        'survey.html', ratings=ratings, blind=blind, seed=seed, signature=signature, success=success,
        exclude_prompted=exclude_prompted, exclude_unprompted=exclude_unprompted, max_epoch=max_epoch,
        experiments=experiments, models_by_id=models_by_id, model_ids=model_ids, errors=[],
        ref_name=ref_name, already_filled=result_file.exists())
|
|
|
|
|
|
@app.route('/audio/<path:path>')
def audio(path: str):
    """Serve the audio file stored at the absolute path `/<path>`.

    Only existing `.mp3` and `.wav` files are served, with the matching content type.
    Anything else gets a 404.
    """
    # Local import: `abort` is not part of the names imported at the top of the file.
    from flask import abort

    # NOTE(review): this route does not use `ensure_logged` and gives access to any
    # mp3 or wav file readable by the server -- confirm this is acceptable.
    content_types = {'.mp3': 'audio/mpeg', '.wav': 'audio/wav'}
    full_path = Path('/') / path
    content_type = content_types.get(full_path.suffix)
    # Explicit check rather than `assert`, which is skipped when running with `-O`.
    if content_type is None or not full_path.is_file():
        abort(404)
    return full_path.read_bytes(), {'Content-Type': content_type}
|
|
|
|
|
|
def mean(x):
    """Return the arithmetic mean of the values in `x`.

    Raises `ZeroDivisionError` if `x` is empty.
    """
    total = sum(x)
    count = len(x)
    return total / count
|
|
|
|
|
|
def std(x):
    """Return the population standard deviation of `x`, i.e. the squared
    deviations to the mean are averaged over `len(x)`, not `len(x) - 1`.
    """
    center = mean(x)
    squared_deviations = [(value - center) ** 2 for value in x]
    variance = sum(squared_deviations) / len(x)
    return math.sqrt(variance)
|
|
|
|
|
|
@app.route('/results/<signature>')
@ensure_logged
def results(signature):
    """Aggregate the ratings collected so far for the survey `signature`.

    Every JSON file of the `results` folder of the survey is read, and its ratings
    are grouped per model signature. The template receives, for each model, the number
    of ratings, their mean and a confidence interval, along with the user of each file.
    Responds with a 404 if the survey does not exist.
    """
    # Local import: `abort` is not part of the names imported at the top of the file.
    from flask import abort

    survey_path = surveys / signature
    if not survey_path.exists():
        abort(404)
    result_folder = survey_path / 'results'
    result_folder.mkdir(exist_ok=True)

    # All the ratings given to each model, across every user and page.
    ratings_per_model = defaultdict(list)
    users = []
    for result_file in result_folder.iterdir():
        if result_file.suffix != '.json':
            continue
        with open(result_file) as f:
            submission = json.load(f)
        # One entry per result file, so a user shows up once per submitted page.
        users.append(submission['user'])
        for sample_result in submission['results']:
            for sig, rating in sample_result['models'].items():
                ratings_per_model[sig].append(rating)

    fmt = '{:.2f}'
    models = []
    for sig in sorted(ratings_per_model):
        ratings = ratings_per_model[sig]
        models.append({
            'sig': sig,
            'samples': len(ratings),
            'mean_rating': fmt.format(mean(ratings)),
            # Despite its name, this is the half width of the 95% confidence interval
            # of the mean under a normal approximation: 1.96 * std / sqrt(n).
            'std_rating': fmt.format(1.96 * std(ratings) / len(ratings)**0.5),
        })
    return render_template('results.html', signature=signature, models=models, users=users)
|
|
|