croissant-editor / views /record_sets.py
marcenacp's picture
Update
dc92053
raw
history blame
13.1 kB
from typing import Any
from typing import Optional

import numpy as np
import pandas as pd
from rdflib import term
import streamlit as st

from core.query_params import expand_record_set
from core.query_params import is_record_set_expanded
from core.state import Field
from core.state import Metadata
from core.state import RecordSet
from core.state import SelectedRecordSet
from events.record_sets import handle_record_set_change
from events.record_sets import RecordSetEvent
import mlcroissant as mlc
from utils import needed_field
from views.source import FieldEvent
from views.source import handle_field_change
from views.source import render_references
from views.source import render_source
# Croissant data types a field can be given, in the order they are offered as
# options by the "Data type" selectors (fields table and field details panel).
DATA_TYPES = [
    mlc.DataType.TEXT, mlc.DataType.FLOAT, mlc.DataType.INTEGER,
    mlc.DataType.BOOL, mlc.DataType.URL,
]
def _handle_close_fields():
    """Callback of the "Close" button: forgets the selected RecordSet.

    With no selection left in the session state, the right panel renders
    nothing on the next run.
    """
    st.session_state[SelectedRecordSet] = None
def _handle_on_click_field(
    record_set_key: int,
    record_set: RecordSet,
):
    """Callback of "Edit fields details": selects a RecordSet for the right panel.

    Args:
        record_set_key: Position of `record_set` in the metadata's RecordSets.
        record_set: The RecordSet whose fields should be shown in detail.
    """
    selection = SelectedRecordSet(
        record_set_key=record_set_key,
        record_set=record_set,
    )
    st.session_state[SelectedRecordSet] = selection
def _data_editor_key(record_set_key: int, record_set: RecordSet) -> str:
    """Returns the widget/session-state key of a RecordSet's fields table.

    The key has the form `<record_set_key>-<record set name>-dataframe`.
    """
    parts = (str(record_set_key), str(record_set.name), "dataframe")
    return "-".join(parts)
def _get_possible_sources(metadata: Metadata) -> list[str]:
    """Lists the names offered as candidate sources/references of a field.

    Returns the name of every resource of the distribution, followed by
    `<record set name>/<field name>` for every field of every RecordSet.
    """
    resource_names = [resource.name for resource in metadata.distribution]
    field_uids = [
        f"{record_set.name}/{field.name}"
        for record_set in metadata.record_sets
        for field in record_set.fields
    ]
    return resource_names + field_uids
# One side of a join: (uid or RecordSet name, key inside it). The key is None
# when the source carries neither a "/" in its uid nor any extraction
# (`_find_left_or_right` returns `(uid, None)` in that case).
LeftOrRight = tuple[str, Optional[str]]
# A join as (left side, right side).
Join = tuple[LeftOrRight, LeftOrRight]
def _find_left_or_right(source: mlc.Source) -> LeftOrRight:
    """Turns a source into the (uid, key) pair shown for one side of a join.

    - A uid containing "/" (`<record set>/<field>`) is split on its FIRST "/",
      so any further "/" stays in the key instead of being dropped.
    - Otherwise the key is the first non-empty extraction among `column`,
      `json_path` and `file_property`, or None when none is set.
    """
    uid = source.uid
    # `"/" in None` would raise a TypeError, so only split a non-empty uid.
    if uid and "/" in uid:
        record_set_name, field_path = uid.split("/", 1)
        return (record_set_name, field_path)
    extract = source.extract
    if extract.column:
        return (uid, extract.column)
    if extract.json_path:
        return (uid, extract.json_path)
    if extract.file_property:
        return (uid, extract.file_property)
    return (uid, None)
def _find_joins(fields: list[Field]) -> set[Join]:
    """Finds the existing joins in the fields.

    A field contributes a join (left=its source, right=its references) when it
    has both a `source` and `references`; identical joins are merged.
    """
    return {
        (
            _find_left_or_right(field.source),
            _find_left_or_right(field.references),
        )
        for field in fields
        if field.source and field.references
    }
def _handle_create_record_set():
    """Callback of "Create a new RecordSet": adds an empty placeholder RecordSet."""
    new_record_set = RecordSet(name="new-record-set", description="")
    st.session_state[Metadata].add_record_set(new_record_set)
def _handle_fields_change(
    record_set_key: int,
    record_set: RecordSet,
    params: Optional[dict[str, Any]] = None,
):
    """Applies the edits made in a RecordSet's fields `st.data_editor`.

    This is the data editor's `on_change` callback, registered with
    `args=(record_set_key, record_set)`.

    Args:
        record_set_key: Position of `record_set` in the metadata's RecordSets.
        record_set: The RecordSet whose fields table was edited.
        params: Unused. Now optional because the callback is only given two
            positional arguments; a required third one made every edit raise
            a TypeError.
    """
    expand_record_set(record_set=record_set)
    data_editor_key = _data_editor_key(record_set_key, record_set)
    result = st.session_state[data_editor_key]
    # `result` has the following structure:
    # ```
    # {'edited_rows': {1: {}}, 'added_rows': [], 'deleted_rows': []}
    # ```
    fields = record_set.fields
    for field_key, new_fields in result["edited_rows"].items():
        field = fields[field_key]
        for new_field, new_value in new_fields.items():
            if new_field == FieldDataFrame.NAME:
                field.name = new_value
            elif new_field == FieldDataFrame.DESCRIPTION:
                field.description = new_value
            elif new_field == FieldDataFrame.DATA_TYPE:
                field.data_types = [new_value]
    for added_row in result["added_rows"]:
        field = Field(
            name=added_row.get(FieldDataFrame.NAME),
            description=added_row.get(FieldDataFrame.DESCRIPTION),
            data_types=[added_row.get(FieldDataFrame.DATA_TYPE)],
            source=mlc.Source(
                uid="foo",
                node_type="distribution",
                extract=mlc.Extract(column=""),
            ),
            references=mlc.Source(),
        )
        st.session_state[Metadata].add_field(record_set_key, field)
    # `deleted_rows` are row positions of the table as it was rendered. Remove
    # them from the highest down so that removing one row does not shift the
    # positions of the rows still to be removed (assumes `remove_field`
    # removes by position; the order is harmless otherwise).
    for field_key in sorted(result["deleted_rows"], reverse=True):
        st.session_state[Metadata].remove_field(record_set_key, field_key)
class FieldDataFrame:
    """Column labels of the fields `pd.DataFrame` edited in the left panel.

    The same labels are the keys of the per-row dicts that the data editor
    reports back in `edited_rows` / `added_rows`.
    """

    # Columns of the fields table built in the left panel.
    NAME = "Name"
    DESCRIPTION = "Description"
    DATA_TYPE = "Data type"

    # Source / reference columns (not part of the table built in the left panel).
    SOURCE_UID = "Source"
    SOURCE_EXTRACT = "Source extract"
    SOURCE_TRANSFORM = "Source transform"
    REFERENCE_UID = "Reference"
    REFERENCE_EXTRACT = "Reference extract"
def render_record_sets():
    """Renders the RecordSets page as two equal-width columns.

    The left column holds the RecordSet forms; the right column holds the
    field details of the selected RecordSet, if any.
    """
    left_column, right_column = st.columns([1, 1])
    with left_column:
        _render_left_panel()
    with right_column:
        _render_right_panel()
def _render_left_panel():
    """Left panel: visualization of all RecordSets as expandable forms.

    Each RecordSet gets an expander with its name, description, enumeration
    flag, a read-only view of its joins, an editable table of its fields and
    an "Edit fields details" button. A "Create a new RecordSet" button follows.
    Only a hint is shown while the distribution is empty.
    """
    distribution = st.session_state[Metadata].distribution
    if not distribution:
        st.markdown("Please add resources first.")
        return
    record_sets = st.session_state[Metadata].record_sets
    record_set: RecordSet
    for record_set_key, record_set in enumerate(record_sets):
        title = f"**{record_set.name or '-'}** ({len(record_set.fields)} fields)"
        prefix = f"record-set-{record_set_key}"
        with st.expander(title, expanded=is_record_set_expanded(record_set)):
            col1, col2 = st.columns([1, 3])
            key = f"{prefix}-name"
            col1.text_input(
                needed_field("Name"),
                placeholder="Name without special character.",
                key=key,
                value=record_set.name,
                on_change=handle_record_set_change,
                args=(RecordSetEvent.NAME, record_set, key),
            )
            key = f"{prefix}-description"
            col2.text_input(
                "Description",
                placeholder="Provide a clear description of the RecordSet.",
                key=key,
                value=record_set.description,
                on_change=handle_record_set_change,
                args=(RecordSetEvent.DESCRIPTION, record_set, key),
            )
            key = f"{prefix}-is-enumeration"
            st.checkbox(
                "Whether the RecordSet is an enumeration",
                key=key,
                value=record_set.is_enumeration,
                on_change=handle_record_set_change,
                args=(RecordSetEvent.IS_ENUMERATION, record_set, key),
            )
            joins = _find_joins(record_set.fields)
            # Disabled: joins are derived from the fields' source/references.
            has_join = st.checkbox(
                "Whether the RecordSet contains joins. To add a new join, add a"
                f" field with a source in `{record_set.name}` and a reference to"
                " another RecordSet or FileSet/FileObject.",
                key=f"{prefix}-has-joins",
                value=bool(joins),
                disabled=True,
            )
            if has_join:
                _render_joins(prefix, joins)
            fields = _fields_dataframe(record_set)
            data_editor_key = _data_editor_key(record_set_key, record_set)
            st.markdown(
                f"{needed_field('Fields')} (add/delete fields by directly editing the"
                " table)"
            )
            st.data_editor(
                fields,
                # There is a bug with `st.data_editor` when the df is empty.
                use_container_width=not fields.empty,
                num_rows="dynamic",
                key=data_editor_key,
                column_config={
                    FieldDataFrame.NAME: st.column_config.TextColumn(
                        FieldDataFrame.NAME,
                        help="Name of the field",
                        required=True,
                    ),
                    FieldDataFrame.DESCRIPTION: st.column_config.TextColumn(
                        FieldDataFrame.DESCRIPTION,
                        help="Description of the field",
                        required=False,
                    ),
                    FieldDataFrame.DATA_TYPE: st.column_config.SelectboxColumn(
                        FieldDataFrame.DATA_TYPE,
                        help="The Croissant type",
                        options=DATA_TYPES,
                        required=True,
                    ),
                },
                on_change=_handle_fields_change,
                args=(record_set_key, record_set),
            )
            st.button(
                "Edit fields details",
                key=f"{prefix}-show-fields",
                on_click=_handle_on_click_field,
                args=(record_set_key, record_set),
            )
    st.button(
        "Create a new RecordSet",
        key="create-new-record-set",
        type="primary",
        on_click=_handle_create_record_set,
    )


def _render_joins(prefix: str, joins: set[Join]):
    """Renders one read-only row (left join/key, right join/key) per join."""
    for left, right in joins:
        # Streamlit rejects duplicate widget keys, and several joins can share
        # the same left side (or the same right side). Every key therefore
        # embeds the whole join, which is unique within the set.
        join_id = f"{left[0]}-{left[1]}-{right[0]}-{right[1]}"
        col1, col2, _, col4, col5 = st.columns([2, 2, 1, 2, 2])
        col1.text_input(
            "Left join",
            disabled=True,
            value=left[0],
            key=f"{prefix}-left-join-{join_id}",
        )
        col2.text_input(
            "Left key",
            disabled=True,
            value=left[1],
            key=f"{prefix}-left-key-{join_id}",
        )
        col4.text_input(
            "Right join",
            disabled=True,
            value=right[0],
            key=f"{prefix}-right-join-{join_id}",
        )
        col5.text_input(
            "Right key",
            disabled=True,
            value=right[1],
            key=f"{prefix}-right-key-{join_id}",
        )


def _fields_dataframe(record_set: RecordSet) -> pd.DataFrame:
    """Builds the (name, description, data type) table of a RecordSet's fields."""
    names = [field.name for field in record_set.fields]
    descriptions = [field.description for field in record_set.fields]
    # TODO(https://github.com/mlcommons/croissant/issues/350): Allow to display
    # several data types, not only the first.
    data_types = [
        field.data_types[0] if field.data_types else None
        for field in record_set.fields
    ]
    return pd.DataFrame(
        {
            FieldDataFrame.NAME: names,
            FieldDataFrame.DESCRIPTION: descriptions,
            FieldDataFrame.DATA_TYPE: data_types,
        },
        dtype=np.str_,
    )
def _render_right_panel():
    """Right panel: visualization of the clicked Field.

    For every field of the RecordSet selected with "Edit fields details",
    renders its name, description, data type, source and references. A
    "Close" button follows. Nothing is rendered when no RecordSet is selected.
    """
    metadata: Metadata = st.session_state.get(Metadata)
    selected: SelectedRecordSet = st.session_state.get(SelectedRecordSet)
    if not selected:
        return
    record_set = selected.record_set
    record_set_key = selected.record_set_key
    # The candidate sources do not depend on the field. Compute them once
    # rather than once per field, and only when there are fields to render.
    possible_sources = (
        _get_possible_sources(metadata) if record_set.fields else []
    )
    with st.expander("**Fields**", expanded=True):
        for field_key, field in enumerate(record_set.fields):
            prefix = f"{record_set_key}-{field.name}-{field_key}"
            col1, col2, col3 = st.columns([1, 1, 1])
            key = f"{prefix}-name"
            col1.text_input(
                needed_field("Name"),
                placeholder="Name without special character.",
                key=key,
                value=field.name,
                on_change=handle_field_change,
                args=(FieldEvent.NAME, field, key),
            )
            key = f"{prefix}-description"
            col2.text_input(
                "Description",
                # This input describes a field, not the RecordSet.
                placeholder="Provide a clear description of the field.",
                key=key,
                on_change=handle_field_change,
                value=field.description,
                args=(FieldEvent.DESCRIPTION, field, key),
            )
            # Preselect the field's first data type when it is one of
            # DATA_TYPES; otherwise leave the selectbox empty.
            data_type_index = None
            if field.data_types:
                data_type = field.data_types[0]
                # Normalize a plain str to an rdflib URIRef before the lookup
                # (presumably the DATA_TYPES entries are URIRefs — confirm).
                if isinstance(data_type, str):
                    data_type = term.URIRef(data_type)
                if data_type in DATA_TYPES:
                    data_type_index = DATA_TYPES.index(data_type)
            key = f"{prefix}-datatypes"
            col3.selectbox(
                needed_field("Data type"),
                index=data_type_index,
                options=DATA_TYPES,
                key=key,
                on_change=handle_field_change,
                args=(FieldEvent.DATA_TYPE, field, key),
            )
            render_source(
                record_set_key, record_set, field, field_key, possible_sources
            )
            render_references(
                record_set_key, record_set, field, field_key, possible_sources
            )
            st.divider()
        st.button(
            "Close",
            key=f"{record_set.name}-{record_set_key}-close-fields",
            type="primary",
            on_click=_handle_close_fields,
        )