File size: 11,330 Bytes
cb5b71d
 
 
 
 
 
 
bbea1cc
cb5b71d
 
 
 
 
 
bbea1cc
 
cb5b71d
bbea1cc
 
cb5b71d
 
bbea1cc
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
041af8a
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
041af8a
 
 
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
5a782ad
cb5b71d
041af8a
 
bbea1cc
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""Streamlit session state.

In the future, this could become the serialization format shared by the frontend and the backend.
"""

from __future__ import annotations

import base64
import dataclasses
import datetime
from typing import Any

from etils import epath
import pandas as pd
import requests
import streamlit as st

from core.constants import OAUTH_CLIENT_ID
from core.constants import OAUTH_CLIENT_SECRET
from core.constants import PAST_PROJECTS_PATH
from core.constants import PROJECT_FOLDER_PATTERN
from core.constants import REDIRECT_URI
import mlcroissant as mlc


def create_class(mlc_class: type, instance: Any, **kwargs) -> Any:
    """Creates the dataclass `mlc_class` from the attributes of `instance`.

    Every dataclass field of `mlc_class` that `instance` also has as an
    attribute is copied over, unless the caller supplies it in `kwargs`
    (explicit keyword arguments always win).

    Args:
        mlc_class: the dataclass to instantiate.
        instance: the object to copy attribute values from.
        **kwargs: extra/overriding constructor arguments.

    Returns:
        A new `mlc_class` instance.
    """
    params: dict[str, Any] = {
        field.name: getattr(instance, field.name)
        for field in dataclasses.fields(mlc_class)
        # `init=False` fields are not constructor parameters: passing them
        # would raise a TypeError.
        if field.init
        and hasattr(instance, field.name)
        and field.name not in kwargs
    }
    return mlc_class(**params, **kwargs)


@dataclasses.dataclass
class User:
    """A user connected through the Hugging Face OAuth flow."""

    access_token: str
    id_token: str
    username: str

    @classmethod
    def connect(cls, code: str, timeout: float = 30.0) -> "User":
        """Exchanges an OAuth authorization `code` for a connected `User`.

        Args:
            code: the authorization code sent back by Hugging Face.
            timeout: seconds to wait for each HTTP request to Hugging Face.

        Returns:
            The connected user.

        Raises:
            Exception: when the token exchange or the user-info lookup does not
                answer with status 200, or the expected fields are missing.
            requests.RequestException: on network failures, including timeouts.
        """
        credentials = base64.b64encode(
            f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
        ).decode()
        response = requests.post(
            "https://huggingface.co/oauth/token",
            data={
                "client_id": OAUTH_CLIENT_ID,
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": REDIRECT_URI,
            },
            headers={"Authorization": f"Basic {credentials}"},
            # Without a timeout, `requests` can block forever on a stalled server.
            timeout=timeout,
        )
        if response.status_code == 200:
            tokens = response.json()
            access_token = tokens.get("access_token")
            id_token = tokens.get("id_token")
            if access_token and id_token:
                response = requests.get(
                    "https://huggingface.co/oauth/userinfo",
                    headers={"Authorization": f"Bearer {access_token}"},
                    timeout=timeout,
                )
                if response.status_code == 200:
                    username = response.json().get("preferred_username")
                    if username:
                        return cls(
                            access_token=access_token,
                            username=username,
                            id_token=id_token,
                        )
        # `response` is always the raw HTTP response here, never the parsed JSON
        # body, so tokens are not echoed into the error message.
        raise Exception(
            f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
            f" ({response=})."
        )


def get_cached_user():
    """Returns the `User` stored in this session's `st.session_state`, or None.

    Deliberately NOT wrapped in `st.cache_data`: that cache is shared by all
    sessions, so caching this zero-argument function would hand one user's
    `User` (including its access token), or a stale None, to every other
    session for the cache lifetime.
    """
    return st.session_state.get(User)


class CurrentStep:
    """String identifiers of the application's top-level steps.

    Used as a namespace of constants; the class is not meant to be instantiated.
    """

    # Identifier of the "splash" step.
    splash: str = "splash"
    # Identifier of the "editor" step.
    editor: str = "editor"


@dataclasses.dataclass
class CurrentProject:
    """The project currently selected in the editor.

    Attributes:
        path: folder of the project.
    """

    path: epath.Path

    @classmethod
    def create_new(cls) -> CurrentProject | None:
        """Builds a project located in a fresh, timestamp-named sub-folder.

        The folder name is the current local time formatted with
        `PROJECT_FOLDER_PATTERN`, under `PAST_PROJECTS_PATH(user)`.

        Returns:
            The new project, or None when OAuth is configured
            (`OAUTH_CLIENT_ID` is set) but no user is connected.
        """
        folder_name = datetime.datetime.now().strftime(PROJECT_FOLDER_PATTERN)
        connected_user = get_cached_user()
        if OAUTH_CLIENT_ID and connected_user is None:
            return None
        projects_root = PAST_PROJECTS_PATH(connected_user)
        return CurrentProject(path=projects_root / folder_name)


class SelectedResource:
    """Marker type for the FileSet or FileObject selected on the `Resources` page.

    Carries no state of its own.
    """


@dataclasses.dataclass
class SelectedRecordSet:
    """The selected RecordSet on the `RecordSets` page.

    Attributes:
        record_set_key: integer key of the selected RecordSet. Presumably its
            index in `Metadata.record_sets`, the way `Metadata._find_record_set`
            uses a parameter of the same name; TODO confirm.
        record_set: the selected RecordSet object.
    """

    # NOTE: field order defines the generated __init__ positional order.
    record_set_key: int
    record_set: RecordSet


@dataclasses.dataclass
class FileObject:
    """FileObject analogue for the editor.

    Converted to and from `mlc.FileObject` with `create_class`, which only
    copies attributes whose names are dataclass fields of the target class.
    """

    # Resource name; other resources refer to it in their `contained_in`.
    name: str | None = None
    description: str | None = None
    # Names of other resources of the distribution (renamed in place by
    # `Metadata.rename_distribution`). Defaults to [] even though None is allowed.
    contained_in: list[str] | None = dataclasses.field(default_factory=list)
    content_size: str | None = None
    content_url: str | None = None
    encoding_format: str | None = None
    # NOTE(review): presumably the SHA-256 digest of the file content; confirm.
    sha256: str | None = None
    # Editor-side DataFrame, presumably a preview of the file content (TODO
    # confirm). Only carried over by `create_class` if the target class also
    # has a `df` field.
    df: pd.DataFrame | None = None
    # mlcroissant RDF object; a fresh `mlc.Rdf()` by default.
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)


@dataclasses.dataclass
class FileSet:
    """FileSet analogue for the editor.

    Converted to and from `mlc.FileSet` with `create_class`, which only copies
    attributes whose names are dataclass fields of the target class.
    """

    # Names of other resources of the distribution (renamed in place by
    # `Metadata.rename_distribution`).
    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    # Defaults to "" (not None), unlike `FileObject.encoding_format`.
    encoding_format: str | None = ""
    # NOTE(review): presumably a pattern selecting the files of the set; confirm.
    includes: str | None = ""
    # Resource name; other resources refer to it in their `contained_in`.
    name: str = ""
    # mlcroissant RDF object; a fresh `mlc.Rdf()` by default.
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)


@dataclasses.dataclass
class Field:
    """Field analogue for the editor.

    Converted to and from `mlc.Field` with `create_class`, which only copies
    attributes whose names are dataclass fields of the target class.
    """

    name: str | None = None
    description: str | None = None
    # A single data type or a list of data types.
    data_types: str | list[str] | None = None
    # Source of the field (presumably where its values are read from). Its
    # `uid` is rewritten by `Metadata.rename_record_set` / `rename_field`.
    source: mlc.Source | None = None
    # mlcroissant RDF object; a fresh `mlc.Rdf()` by default.
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)
    # Referenced element; its `uid` is rewritten the same way as `source`'s.
    references: mlc.Source | None = None


@dataclasses.dataclass
class RecordSet:
    """RecordSet analogue for the editor.

    Converted to and from `mlc.RecordSet` with `create_class`; `fields` is
    converted separately and passed explicitly in both directions.
    """

    name: str = ""
    # NOTE(review): untyped; presumably inline records of the RecordSet; confirm.
    data: Any = None
    description: str | None = None
    is_enumeration: bool | None = None
    # A single key or a list of keys.
    key: str | list[str] | None = None
    fields: list[Field] = dataclasses.field(default_factory=list)
    # mlcroissant RDF object; a fresh `mlc.Rdf()` by default.
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)


@dataclasses.dataclass
class Metadata:
    """Main Croissant data object of the editor.

    `to_canonical` / `from_canonical` convert it to and from `mlc.Metadata`.
    The `rename_*` methods keep cross-references consistent (resource
    `contained_in` lists and field `source` / `references` uids) when a
    resource, RecordSet or field is renamed.
    """

    name: str = ""
    description: str | None = None
    citation: str | None = None
    license: str | None = ""
    url: str = ""
    distribution: list[FileObject | FileSet] = dataclasses.field(default_factory=list)
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)

    def __bool__(self):
        """Returns True only when both `name` and `url` are non-empty."""
        return self.name != "" and self.url != ""

    @staticmethod
    def _rename_first_segment(
        uid: str | None, old_name: str, new_name: str
    ) -> str | None:
        """Returns `uid` with a leading `old_name` segment renamed, else None.

        `uid` matches only when it is exactly `old_name` or starts with
        `old_name/`; a plain prefix test would also match unrelated names
        (renaming "data" must not touch "data_2/field").
        """
        if not uid:
            return None
        if uid == old_name or uid.startswith(f"{old_name}/"):
            return new_name + uid[len(old_name) :]
        return None

    @staticmethod
    def _rename_last_segment(
        uid: str | None, old_name: str, new_name: str
    ) -> str | None:
        """Returns `uid` with a trailing `old_name` segment renamed, else None.

        Only uids containing a "/" match, and only when their last segment is
        exactly `old_name` (renaming "id" must not touch "users/user_id"). The
        last segment is rebuilt explicitly because `str.replace(..., 1)` would
        rewrite the first occurrence, which may be the RecordSet part of the
        uid (e.g. "name/name").
        """
        if not uid:
            return None
        head, separator, last = uid.rpartition("/")
        if not separator or last != old_name:
            return None
        return f"{head}/{new_name}"

    def _update_field_uids(self, rename) -> None:
        """Rewrites the `source` and `references` uids of every field in place.

        Args:
            rename: callable mapping a uid to its new value, or to None to leave
                that uid unchanged.
        """
        for record_set in self.record_sets:
            for field in record_set.fields:
                for reference in (field.source, field.references):
                    if not reference:
                        continue
                    new_uid = rename(reference.uid)
                    if new_uid is not None:
                        reference.uid = new_uid

    def rename_distribution(self, old_name: str, new_name: str):
        """Renames a resource by changing all the references to this resource.

        Updates the `contained_in` lists of the resources, then the field
        `source` / `references` uids that point at the resource.
        """
        for resource in self.distribution:
            contained_in = resource.contained_in
            if contained_in and old_name in contained_in:
                resource.contained_in = [
                    new_name if name == old_name else name for name in contained_in
                ]
        # Fields reference resources exactly like they reference RecordSets.
        self.rename_record_set(old_name, new_name)

    def rename_record_set(self, old_name: str, new_name: str):
        """Renames a RecordSet by changing all the references to this RecordSet.

        A field uid refers to the RecordSet when it equals `old_name` or starts
        with `old_name/`; only that leading segment is renamed.
        """
        self._update_field_uids(
            lambda uid: self._rename_first_segment(uid, old_name, new_name)
        )

    def rename_field(self, old_name: str, new_name: str):
        """Renames a field by changing all the references to this field.

        A field uid refers to the field when it contains a "/" and its last
        segment equals `old_name`; only that last segment is renamed.
        NOTE: the RecordSet part of the uid is not checked, so references to
        same-named fields of other RecordSets are renamed as well.
        """
        self._update_field_uids(
            lambda uid: self._rename_last_segment(uid, old_name, new_name)
        )

    def add_distribution(self, distribution: FileSet | FileObject) -> None:
        """Appends a resource to the distribution."""
        self.distribution.append(distribution)

    def remove_distribution(self, key: int) -> None:
        """Deletes the resource at index `key` (IndexError if out of range)."""
        del self.distribution[key]

    def add_record_set(self, record_set: RecordSet) -> None:
        """Appends a RecordSet."""
        self.record_sets.append(record_set)

    def remove_record_set(self, key: int) -> None:
        """Deletes the RecordSet at index `key` (IndexError if out of range)."""
        del self.record_sets[key]

    def _find_record_set(self, record_set_key: int) -> RecordSet:
        """Returns the RecordSet at index `record_set_key`.

        Raises:
            ValueError: if the index is out of range. Negative indices are
                rejected too, instead of silently counting from the end.
        """
        if not 0 <= record_set_key < len(self.record_sets):
            raise ValueError(f"Wrong index when finding a RecordSet: {record_set_key}")
        return self.record_sets[record_set_key]

    def add_field(self, record_set_key: int, field: Field) -> None:
        """Appends `field` to the RecordSet at index `record_set_key`."""
        record_set = self._find_record_set(record_set_key)
        record_set.fields.append(field)

    def remove_field(self, record_set_key: int, field_key: int) -> None:
        """Deletes field `field_key` of the RecordSet at index `record_set_key`.

        Raises:
            ValueError: if either index is out of range (negative included).
        """
        record_set = self._find_record_set(record_set_key)
        if not 0 <= field_key < len(record_set.fields):
            raise ValueError(f"Wrong index when removing field: {field_key}")
        del record_set.fields[field_key]

    def to_canonical(self) -> mlc.Metadata:
        """Converts the editor metadata into a canonical `mlc.Metadata`.

        Distribution items that are neither `FileObject` nor `FileSet` are
        skipped.
        """
        distribution = []
        for file in self.distribution:
            if isinstance(file, FileObject):
                distribution.append(create_class(mlc.FileObject, file))
            elif isinstance(file, FileSet):
                distribution.append(create_class(mlc.FileSet, file))
        record_sets = [
            create_class(
                mlc.RecordSet,
                record_set,
                fields=[create_class(mlc.Field, field) for field in record_set.fields],
            )
            for record_set in self.record_sets
        ]
        return create_class(
            mlc.Metadata,
            self,
            distribution=distribution,
            record_sets=record_sets,
        )

    @classmethod
    def from_canonical(cls, canonical_metadata: mlc.Metadata) -> Metadata:
        """Builds the editor metadata from a canonical `mlc.Metadata`.

        Every distribution item that is not an `mlc.FileObject` is converted
        to a `FileSet`.
        """
        distribution = [
            create_class(
                FileObject if isinstance(file, mlc.FileObject) else FileSet, file
            )
            for file in canonical_metadata.distribution
        ]
        record_sets = [
            create_class(
                RecordSet,
                record_set,
                fields=[create_class(Field, field) for field in record_set.fields],
            )
            for record_set in canonical_metadata.record_sets
        ]
        return create_class(
            cls,
            canonical_metadata,
            distribution=distribution,
            record_sets=record_sets,
        )