File size: 1,342 Bytes
4aaf9d8
bd6e417
 
4aaf9d8
bd6e417
4aaf9d8
 
 
bd6e417
 
4aaf9d8
 
 
 
 
bd6e417
 
bead200
bd6e417
 
 
 
fac7be5
bd6e417
 
 
 
fac7be5
4aaf9d8
bd6e417
 
 
 
 
fac7be5
 
 
 
 
 
 
 
 
 
bd6e417
 
 
 
4aaf9d8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import streamlit as st
import streamlit_pianoroll
from datasets import load_dataset
from piano_dataset import PianoTasks
from fortepyan import MidiPiece


@st.cache_resource
def _load_review_dataset(path: str, split: str):
    """Load a Hugging Face dataset split once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``load_dataset`` call would repeat the (potentially slow) load
    each time the user changes the task or record index.
    """
    return load_dataset(path, split=split)


def main():
    """Render a page for reviewing how a PIANO task splits a record.

    The user selects a PIANO task and a record index from the
    ``epr-labs/maestro-sustain-v2`` test+validation splits. The record is
    converted to a ``MidiPiece``, split by the task into prompt and target
    note frames, and rendered as two side-by-side piano rolls followed by a
    combined view of both.
    """
    st.write("# PIANO dataset task review")

    available_tasks = PianoTasks.list_tasks()
    task_name = st.selectbox(
        label="Select PIANO task",
        options=available_tasks,
    )
    piano_task = PianoTasks.get_task(task_name=task_name)

    dataset = _load_review_dataset(
        "epr-labs/maestro-sustain-v2",
        split="test+validation",
    )
    if len(dataset) == 0:
        # Without this guard max_value would be -1 and number_input would fail.
        st.error("The review dataset is empty.")
        return

    max_record_idx = len(dataset) - 1
    record_idx = st.number_input(
        label="Record ID",
        min_value=0,
        max_value=max_record_idx,
        # number_input rejects a default outside [min_value, max_value];
        # clamp the preferred starting record to the dataset size.
        value=min(43, max_record_idx),
    )

    record = dataset[record_idx]
    piece = MidiPiece.from_huggingface(record)
    st.json(piece.source)

    piece_split = piano_task.prompt_target_split(piece.df)

    # source_df holds the prompt notes, target_df the notes to be predicted.
    source_piece = MidiPiece(df=piece_split.source_df)
    target_piece = MidiPiece(df=piece_split.target_df)

    cols = st.columns(2)
    with cols[0]:
        st.write("## Prompt")
        streamlit_pianoroll.from_fortepyan(source_piece)

    with cols[1]:
        st.write("## Target")
        streamlit_pianoroll.from_fortepyan(target_piece)

    st.write("## Combined")
    streamlit_pianoroll.from_fortepyan(
        piece=source_piece,
        secondary_piece=target_piece,
    )


if __name__ == "__main__":
    main()