asoria HF staff commited on
Commit
dc70c7b
·
1 Parent(s): a893b55

some refactor

Browse files
Files changed (4) hide show
  1. .vscode/launch.json +14 -0
  2. app.py +21 -63
  3. src/templates.py +3 -11
  4. src/viewer_api.py +51 -0
.vscode/launch.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Python Debugger: Current File",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "program": "${file}",
9
+ "console": "integratedTerminal",
10
+ "purpose": ["debug-test"],
11
+ "justMyCode": false
12
+ }
13
+ ]
14
+ }
app.py CHANGED
@@ -5,9 +5,7 @@ import logging
5
  import os
6
 
7
  import datamapplot
8
- import duckdb
9
  import numpy as np
10
- import requests
11
 
12
  from dotenv import load_dotenv
13
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
@@ -28,26 +26,28 @@ from transformers import (
28
 
29
  from src.hub import create_space_with_content
30
  from src.templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
 
 
 
 
 
 
 
31
 
 
32
  load_dotenv()
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
35
-
36
-
37
- EXPORTS_REPOSITORY = os.getenv("EXPORTS_REPOSITORY")
38
- assert (
39
- EXPORTS_REPOSITORY is not None
40
- ), "You need to set EXPORTS_REPOSITORY in your environment variables"
41
-
42
  MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
43
  CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
44
- DATASET_VIEWE_API_URL = "https://datasets-server.huggingface.co/"
45
  DATASETS_TOPICS_ORGANIZATION = os.getenv(
46
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
47
  )
48
  USE_ARROW_STYLE = int(os.getenv("USE_ARROW_STYLE", "0"))
49
- USE_CUML = int(os.getenv("USE_CUML", "0"))
 
50
 
 
51
  if USE_CUML:
52
  from cuml.manifold import UMAP
53
  from cuml.cluster import HDBSCAN
@@ -55,14 +55,12 @@ else:
55
  from umap import UMAP
56
  from hdbscan import HDBSCAN
57
 
58
- USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
59
 
60
  logging.basicConfig(
61
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
62
  )
63
 
64
  api = HfApi(token=HF_TOKEN)
65
- session = requests.Session()
66
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
67
 
68
  # Representation model
@@ -98,41 +96,6 @@ else:
98
  vectorizer_model = CountVectorizer(stop_words="english")
99
 
100
 
101
- def get_split_rows(dataset, config, split):
102
- config_size = session.get(
103
- f"{DATASET_VIEWE_API_URL}/size?dataset={dataset}&config={config}",
104
- timeout=20,
105
- ).json()
106
- if "error" in config_size:
107
- raise Exception(f"Error fetching config size: {config_size['error']}")
108
- split_size = next(
109
- (s for s in config_size["size"]["splits"] if s["split"] == split),
110
- None,
111
- )
112
- if split_size is None:
113
- raise Exception(f"Error fetching split {split} in config {config}")
114
- return split_size["num_rows"]
115
-
116
-
117
- def get_parquet_urls(dataset, config, split):
118
- parquet_files = session.get(
119
- f"{DATASET_VIEWE_API_URL}/parquet?dataset={dataset}&config={config}&split={split}",
120
- timeout=20,
121
- ).json()
122
- if "error" in parquet_files:
123
- raise Exception(f"Error fetching parquet files: {parquet_files['error']}")
124
- parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
125
- logging.debug(f"Parquet files: {parquet_urls}")
126
- return ",".join(f"'{url}'" for url in parquet_urls)
127
-
128
-
129
- def get_docs_from_parquet(parquet_urls, column, offset, limit):
130
- SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
131
- df = duckdb.sql(SQL_QUERY).to_df()
132
- return df[column].tolist()
133
-
134
-
135
- # @spaces.GPU
136
  def calculate_embeddings(docs):
137
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
138
 
@@ -143,7 +106,6 @@ def calculate_n_neighbors_and_components(n_rows):
143
  return n_neighbors, n_components
144
 
145
 
146
- # @spaces.GPU
147
  def fit_model(docs, embeddings, n_neighbors, n_components):
148
  umap_model = UMAP(
149
  n_neighbors=n_neighbors,
@@ -254,18 +216,16 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
254
  reduced_embeddings_array = np.vstack(reduced_embeddings_list)
255
 
256
  topics_info = base_model.get_topic_info()
257
- all_topics, _ = base_model.transform(all_docs)
258
- all_topics = np.array(all_topics)
259
-
260
  sub_title = (
261
  f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
262
  if full_processing
263
  else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
264
  )
265
-
266
  topic_plot = (
267
  base_model.visualize_document_datamap(
268
  docs=all_docs,
 
269
  reduced_embeddings=reduced_embeddings_array,
270
  title=dataset,
271
  sub_title=sub_title,
@@ -291,7 +251,6 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
291
  title=dataset,
292
  )
293
  )
294
-
295
  rows_processed += len(docs)
296
  progress = min(rows_processed / limit, 1.0)
297
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
@@ -320,10 +279,10 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
320
  else:
321
  topic_plot.write_image(plot_png)
322
 
323
- all_topics, _ = base_model.transform(all_docs)
324
  topic_info = base_model.get_topic_info()
325
 
326
- topic_names = {row["Topic"]: row["Name"] for index, row in topic_info.iterrows()}
327
  topic_names_array = np.array(
328
  [
329
  topic_names.get(topic, "No Topic").split("_")[1].strip("-")
@@ -461,21 +420,20 @@ with gr.Blocks() as demo:
461
  text_column_dropdown: gr.Dropdown(label="Text column name"),
462
  nested_text_column_dropdown: gr.Dropdown(visible=False),
463
  }
464
- info_resp = session.get(
465
- f"{DATASET_VIEWE_API_URL}/info?dataset={dataset}", timeout=20
466
- ).json()
467
- if "error" in info_resp:
468
  return {
469
  subset_dropdown: gr.Dropdown(visible=False),
470
  split_dropdown: gr.Dropdown(visible=False),
471
  text_column_dropdown: gr.Dropdown(label="Text column name"),
472
  nested_text_column_dropdown: gr.Dropdown(visible=False),
473
  }
474
- subsets: list[str] = list(info_resp["dataset_info"])
475
  subset = default_subset if default_subset in subsets else subsets[0]
476
- splits: list[str] = list(info_resp["dataset_info"][subset]["splits"])
477
  split = default_split if default_split in splits else splits[0]
478
- features = info_resp["dataset_info"][subset]["features"]
479
 
480
  def _is_string_feature(feature):
481
  return isinstance(feature, dict) and feature.get("dtype") == "string"
 
5
  import os
6
 
7
  import datamapplot
 
8
  import numpy as np
 
9
 
10
  from dotenv import load_dotenv
11
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
 
26
 
27
  from src.hub import create_space_with_content
28
  from src.templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
29
+ from src.viewer_api import (
30
+ get_split_rows,
31
+ get_parquet_urls,
32
+ get_docs_from_parquet,
33
+ get_info,
34
+ )
35
+
36
 
37
+ # Load environment variables
38
  load_dotenv()
39
  HF_TOKEN = os.getenv("HF_TOKEN")
40
  assert HF_TOKEN is not None, "You need to set HF_TOKEN in your environment variables"
 
 
 
 
 
 
 
41
  MAX_ROWS = int(os.getenv("MAX_ROWS", "8_000"))
42
  CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
 
43
  DATASETS_TOPICS_ORGANIZATION = os.getenv(
44
  "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
45
  )
46
  USE_ARROW_STYLE = int(os.getenv("USE_ARROW_STYLE", "0"))
47
+ USE_CUML = int(os.getenv("USE_CUML", "1"))
48
+ USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
49
 
50
+ # Use cuml lib only if configured
51
  if USE_CUML:
52
  from cuml.manifold import UMAP
53
  from cuml.cluster import HDBSCAN
 
55
  from umap import UMAP
56
  from hdbscan import HDBSCAN
57
 
 
58
 
59
  logging.basicConfig(
60
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
61
  )
62
 
63
  api = HfApi(token=HF_TOKEN)
 
64
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
65
 
66
  # Representation model
 
96
  vectorizer_model = CountVectorizer(stop_words="english")
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def calculate_embeddings(docs):
100
  return sentence_model.encode(docs, show_progress_bar=True, batch_size=32)
101
 
 
106
  return n_neighbors, n_components
107
 
108
 
 
109
  def fit_model(docs, embeddings, n_neighbors, n_components):
110
  umap_model = UMAP(
111
  n_neighbors=n_neighbors,
 
216
  reduced_embeddings_array = np.vstack(reduced_embeddings_list)
217
 
218
  topics_info = base_model.get_topic_info()
219
+ all_topics = base_model.topics_
 
 
220
  sub_title = (
221
  f"Data map for the entire dataset ({limit} rows) using the column '{column}'"
222
  if full_processing
223
  else f"Data map for a sample of the dataset (first {limit} rows) using the column '{column}'"
224
  )
 
225
  topic_plot = (
226
  base_model.visualize_document_datamap(
227
  docs=all_docs,
228
+ topics=all_topics,
229
  reduced_embeddings=reduced_embeddings_array,
230
  title=dataset,
231
  sub_title=sub_title,
 
251
  title=dataset,
252
  )
253
  )
 
254
  rows_processed += len(docs)
255
  progress = min(rows_processed / limit, 1.0)
256
  logging.info(f"Progress: {progress} % - {rows_processed} of {limit}")
 
279
  else:
280
  topic_plot.write_image(plot_png)
281
 
282
+ all_topics = base_model.topics_
283
  topic_info = base_model.get_topic_info()
284
 
285
+ topic_names = {row["Topic"]: row["Name"] for _, row in topic_info.iterrows()}
286
  topic_names_array = np.array(
287
  [
288
  topic_names.get(topic, "No Topic").split("_")[1].strip("-")
 
420
  text_column_dropdown: gr.Dropdown(label="Text column name"),
421
  nested_text_column_dropdown: gr.Dropdown(visible=False),
422
  }
423
+ try:
424
+ info_resp = get_info(dataset)
425
+ except Exception:
 
426
  return {
427
  subset_dropdown: gr.Dropdown(visible=False),
428
  split_dropdown: gr.Dropdown(visible=False),
429
  text_column_dropdown: gr.Dropdown(label="Text column name"),
430
  nested_text_column_dropdown: gr.Dropdown(visible=False),
431
  }
432
+ subsets: list[str] = list(info_resp)
433
  subset = default_subset if default_subset in subsets else subsets[0]
434
+ splits: list[str] = list(info_resp[subset]["splits"])
435
  split = default_split if default_split in splits else splits[0]
436
+ features = info_resp[subset]["features"]
437
 
438
  def _is_string_feature(feature):
439
  return isinstance(feature, dict) and feature.get("dtype") == "string"
src/templates.py CHANGED
@@ -5,12 +5,7 @@ You are a helpful, respectful and honest assistant for labeling topics.
5
  """
6
 
7
  EXAMPLE_PROMPT = """
8
- I have a topic that contains the following documents:
9
- - Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
10
- - Meat, but especially beef, is the word food in terms of emissions.
11
- - Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
12
-
13
- The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.
14
 
15
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
16
 
@@ -19,10 +14,7 @@ Based on the information about the topic above, please create a short label of t
19
 
20
  MAIN_PROMPT = """
21
  [INST]
22
- I have a topic that contains the following documents:
23
- [DOCUMENTS]
24
-
25
- The topic is described by the following keywords: '[KEYWORDS]'.
26
 
27
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
28
  [/INST]
@@ -32,7 +24,7 @@ REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
32
 
33
  SPACE_REPO_CARD_CONTENT = """
34
  ---
35
- title: {dataset_id} topic modeling
36
  sdk: static
37
  pinned: false
38
  datasets:
 
5
  """
6
 
7
  EXAMPLE_PROMPT = """
8
+ I have a topic that is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.
 
 
 
 
 
9
 
10
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
11
 
 
14
 
15
  MAIN_PROMPT = """
16
  [INST]
17
+ I have a topic that is described by the following keywords: '[KEYWORDS]'.
 
 
 
18
 
19
  Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
20
  [/INST]
 
24
 
25
  SPACE_REPO_CARD_CONTENT = """
26
  ---
27
+ title: {dataset_id}
28
  sdk: static
29
  pinned: false
30
  datasets:
src/viewer_api.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import duckdb
3
+
4
+ DATASET_VIEWER_API_URL = "https://datasets-server.huggingface.co/"
5
+ session = requests.Session()
6
+
7
+
8
+ def fetch_json(url, params=None, timeout=20):
9
+ response = session.get(url, params=params, timeout=timeout)
10
+ response.raise_for_status()
11
+ data = response.json()
12
+ if "error" in data:
13
+ raise Exception(f"Error fetching data: {data['error']}")
14
+ return data
15
+
16
+
17
+ def get_split_rows(dataset, config, split):
18
+ url = f"{DATASET_VIEWER_API_URL}/size"
19
+ params = {"dataset": dataset, "config": config}
20
+ config_size = fetch_json(url, params)
21
+
22
+ split_size = next(
23
+ (s for s in config_size["size"]["splits"] if s["split"] == split), None
24
+ )
25
+ if split_size is None:
26
+ raise Exception(f"Error fetching split {split} in config {config}")
27
+
28
+ return split_size["num_rows"]
29
+
30
+
31
+ def get_parquet_urls(dataset, config, split):
32
+ url = f"{DATASET_VIEWER_API_URL}/parquet"
33
+ params = {"dataset": dataset, "config": config, "split": split}
34
+ parquet_files = fetch_json(url, params)
35
+
36
+ parquet_urls = [file["url"] for file in parquet_files["parquet_files"]]
37
+ return ",".join(f"'{url}'" for url in parquet_urls)
38
+
39
+
40
+ def get_docs_from_parquet(parquet_urls, column, offset, limit):
41
+ sql_query = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
42
+ df = duckdb.sql(sql_query).to_df()
43
+ return df[column].tolist()
44
+
45
+
46
+ def get_info(dataset):
47
+ url = f"{DATASET_VIEWER_API_URL}/info"
48
+ params = {"dataset": dataset}
49
+ info_resp = fetch_json(url, params)
50
+
51
+ return info_resp["dataset_info"]