asoria HF staff commited on
Commit
95d1f22
·
1 Parent(s): cefd61d

Parameterize behavior

Browse files
Files changed (3) hide show
  1. app.py +54 -56
  2. requirements.txt +2 -2
  3. prompts.py → templates.py +11 -0
app.py CHANGED
@@ -19,7 +19,7 @@ from bertopic.representation import TextGeneration
19
  from huggingface_hub import HfApi, SpaceCard
20
  from sklearn.feature_extraction.text import CountVectorizer
21
  from sentence_transformers import SentenceTransformer
22
- from prompts import REPRESENTATION_PROMPT
23
  from torch import cuda, bfloat16
24
  from transformers import (
25
  BitsAndBytesConfig,
@@ -27,11 +27,6 @@ from transformers import (
27
  AutoModelForCausalLM,
28
  pipeline,
29
  )
30
- # from cuml.manifold import UMAP
31
- # from cuml.cluster import HDBSCAN
32
-
33
- from umap import UMAP
34
- from hdbscan import HDBSCAN
35
 
36
  """
37
  TODOs:
@@ -51,52 +46,68 @@ assert (
51
  EXPORTS_REPOSITORY is not None
52
  ), "You need to set EXPORTS_REPOSITORY in your environment variables"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  logging.basicConfig(
55
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
56
  )
57
 
58
- MAX_ROWS = 50_000
59
- CHUNK_SIZE = 10_000
60
-
61
  api = HfApi(token=HF_TOKEN)
62
-
63
  session = requests.Session()
64
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
65
 
66
  # Representation model
67
- bnb_config = BitsAndBytesConfig(
68
- load_in_4bit=True,
69
- bnb_4bit_quant_type="nf4",
70
- bnb_4bit_use_double_quant=True,
71
- bnb_4bit_compute_dtype=bfloat16,
72
- )
 
73
 
74
- model_id = "meta-llama/Llama-2-7b-chat-hf"
75
- tokenizer = AutoTokenizer.from_pretrained(model_id)
76
- model = AutoModelForCausalLM.from_pretrained(
77
- model_id,
78
- trust_remote_code=True,
79
- quantization_config=bnb_config,
80
- device_map="auto",
81
- )
82
- model.eval()
83
- generator = pipeline(
84
- model=model,
85
- tokenizer=tokenizer,
86
- task="text-generation",
87
- temperature=0.1,
88
- max_new_tokens=500,
89
- repetition_penalty=1.1,
90
- )
91
- representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
92
- # End of representation model
 
93
 
94
  vectorizer_model = CountVectorizer(stop_words="english")
95
 
96
 
97
  def get_split_rows(dataset, config, split):
98
  config_size = session.get(
99
- f"https://datasets-server.huggingface.co/size?dataset={dataset}&config={config}",
100
  timeout=20,
101
  ).json()
102
  if "error" in config_size:
@@ -112,7 +123,7 @@ def get_split_rows(dataset, config, split):
112
 
113
  def get_parquet_urls(dataset, config, split):
114
  parquet_files = session.get(
115
- f"https://datasets-server.huggingface.co/parquet?dataset={dataset}&config={config}&split={split}",
116
  timeout=20,
117
  ).json()
118
  if "error" in parquet_files:
@@ -125,7 +136,6 @@ def get_parquet_urls(dataset, config, split):
125
  def get_docs_from_parquet(parquet_urls, column, offset, limit):
126
  SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
127
  df = duckdb.sql(SQL_QUERY).to_df()
128
- logging.debug(f"Dataframe: {df.head(5)}")
129
  return df[column].tolist()
130
 
131
 
@@ -200,8 +210,7 @@ def _push_to_hub(
200
 
201
 
202
  def create_space_with_content(dataset_id, html_file_path):
203
- # TODO: Parameterize organization name
204
- repo_id = f"datasets-topics/{dataset_id.replace('/', '-')}"
205
  logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")
206
  api.create_repo(
207
  repo_id=repo_id,
@@ -211,16 +220,6 @@ def create_space_with_content(dataset_id, html_file_path):
211
  token=HF_TOKEN,
212
  space_sdk="static",
213
  )
214
- SPACE_REPO_CARD_CONTENT = """
215
- ---
216
- title: {dataset_id} topic modeling
217
- sdk: static
218
- pinned: false
219
- datasets:
220
- - {dataset_id}
221
- ---
222
-
223
- """
224
 
225
  SpaceCard(
226
  content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id)
@@ -233,14 +232,14 @@ datasets:
233
  repo_id=repo_id,
234
  token=HF_TOKEN,
235
  )
236
- logging.info(f"Space created done")
237
  return repo_id
238
 
239
 
240
  @spaces.GPU(duration=120)
241
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
242
  logging.info(
243
- f"Generating topics for {dataset} with config {config} {split} {column} {nested_column}"
244
  )
245
 
246
  parquet_urls = get_parquet_urls(dataset, config, split)
@@ -326,8 +325,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
326
  "linewidth": 0,
327
  "fc": "#33333377",
328
  },
329
- # TODO: Make it configurable in UI
330
- dynamic_label_size=False,
331
  # label_wrap_width=12,
332
  # label_over_points=True,
333
  # dynamic_label_size=True,
@@ -395,7 +393,7 @@ def generate_topics(dataset, config, split, column, nested_column, plot_type):
395
  # TODO: Export data to .arrow and also serve it
396
  inline_data=True,
397
  # offline_data_prefix=dataset_clear_name,
398
- initial_zoom_fraction=0.9,
399
  )
400
  html_content = str(interactive_plot)
401
  html_file_path = f"{dataset_clear_name}.html"
@@ -503,7 +501,7 @@ with gr.Blocks() as demo:
503
  nested_text_column_dropdown: gr.Dropdown(visible=False),
504
  }
505
  info_resp = session.get(
506
- f"https://datasets-server.huggingface.co/info?dataset={dataset}", timeout=20
507
  ).json()
508
  if "error" in info_resp:
509
  return {
 
19
  from huggingface_hub import HfApi, SpaceCard
20
  from sklearn.feature_extraction.text import CountVectorizer
21
  from sentence_transformers import SentenceTransformer
22
+ from templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
23
  from torch import cuda, bfloat16
24
  from transformers import (
25
  BitsAndBytesConfig,
 
27
  AutoModelForCausalLM,
28
  pipeline,
29
  )
 
 
 
 
 
30
 
31
  """
32
  TODOs:
 
46
  EXPORTS_REPOSITORY is not None
47
  ), "You need to set EXPORTS_REPOSITORY in your environment variables"
48
 
49
+ MAX_ROWS = int(os.getenv("MAX_ROWS", "10_000"))
50
+ CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "2_000"))
51
+ DATASET_VIEWE_API_URL = "https://datasets-server.huggingface.co/"
52
+ DATASETS_TOPICS_ORGANIZATION = os.getenv(
53
+ "DATASETS_TOPICS_ORGANIZATION", "datasets-topics"
54
+ )
55
+ USE_ARROW_STYLE = int(os.getenv("USE_ARROW_STYLE", "0"))
56
+ USE_CUML = int(os.getenv("USE_CUML", "0"))
57
+
58
+ if USE_CUML:
59
+ from cuml.manifold import UMAP
60
+ from cuml.cluster import HDBSCAN
61
+ else:
62
+ from umap import UMAP
63
+ from hdbscan import HDBSCAN
64
+
65
+ USE_LLM_TEXT_GENERATION = int(os.getenv("USE_LLM_TEXT_GENERATION", "1"))
66
+
67
  logging.basicConfig(
68
  level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
69
  )
70
 
 
 
 
71
  api = HfApi(token=HF_TOKEN)
 
72
  session = requests.Session()
73
  sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
74
 
75
  # Representation model
76
+ if USE_LLM_TEXT_GENERATION:
77
+ bnb_config = BitsAndBytesConfig(
78
+ load_in_4bit=True,
79
+ bnb_4bit_quant_type="nf4",
80
+ bnb_4bit_use_double_quant=True,
81
+ bnb_4bit_compute_dtype=bfloat16,
82
+ )
83
 
84
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
85
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
86
+ model = AutoModelForCausalLM.from_pretrained(
87
+ model_id,
88
+ trust_remote_code=True,
89
+ quantization_config=bnb_config,
90
+ device_map="auto",
91
+ )
92
+ model.eval()
93
+ generator = pipeline(
94
+ model=model,
95
+ tokenizer=tokenizer,
96
+ task="text-generation",
97
+ temperature=0.1,
98
+ max_new_tokens=500,
99
+ repetition_penalty=1.1,
100
+ )
101
+ representation_model = TextGeneration(generator, prompt=REPRESENTATION_PROMPT)
102
+ else:
103
+ representation_model = KeyBERTInspired()
104
 
105
  vectorizer_model = CountVectorizer(stop_words="english")
106
 
107
 
108
  def get_split_rows(dataset, config, split):
109
  config_size = session.get(
110
+ f"{DATASET_VIEWE_API_URL}/size?dataset={dataset}&config={config}",
111
  timeout=20,
112
  ).json()
113
  if "error" in config_size:
 
123
 
124
  def get_parquet_urls(dataset, config, split):
125
  parquet_files = session.get(
126
+ f"{DATASET_VIEWE_API_URL}/parquet?dataset={dataset}&config={config}&split={split}",
127
  timeout=20,
128
  ).json()
129
  if "error" in parquet_files:
 
136
  def get_docs_from_parquet(parquet_urls, column, offset, limit):
137
  SQL_QUERY = f"SELECT {column} FROM read_parquet([{parquet_urls}]) LIMIT {limit} OFFSET {offset};"
138
  df = duckdb.sql(SQL_QUERY).to_df()
 
139
  return df[column].tolist()
140
 
141
 
 
210
 
211
 
212
  def create_space_with_content(dataset_id, html_file_path):
213
+ repo_id = f"{DATASETS_TOPICS_ORGANIZATION}/{dataset_id.replace('/', '-')}"
 
214
  logging.info(f"Creating space with content: {repo_id} on file {html_file_path}")
215
  api.create_repo(
216
  repo_id=repo_id,
 
220
  token=HF_TOKEN,
221
  space_sdk="static",
222
  )
 
 
 
 
 
 
 
 
 
 
223
 
224
  SpaceCard(
225
  content=SPACE_REPO_CARD_CONTENT.format(dataset_id=dataset_id)
 
232
  repo_id=repo_id,
233
  token=HF_TOKEN,
234
  )
235
+ logging.info(f"Space creation done")
236
  return repo_id
237
 
238
 
239
  @spaces.GPU(duration=120)
240
  def generate_topics(dataset, config, split, column, nested_column, plot_type):
241
  logging.info(
242
+ f"Generating topics for {dataset=} {config=} {split=} {column=} {nested_column=} {plot_type=}"
243
  )
244
 
245
  parquet_urls = get_parquet_urls(dataset, config, split)
 
325
  "linewidth": 0,
326
  "fc": "#33333377",
327
  },
328
+ dynamic_label_size=USE_ARROW_STYLE,
 
329
  # label_wrap_width=12,
330
  # label_over_points=True,
331
  # dynamic_label_size=True,
 
393
  # TODO: Export data to .arrow and also serve it
394
  inline_data=True,
395
  # offline_data_prefix=dataset_clear_name,
396
+ initial_zoom_fraction=0.8,
397
  )
398
  html_content = str(interactive_plot)
399
  html_file_path = f"{dataset_clear_name}.html"
 
501
  nested_text_column_dropdown: gr.Dropdown(visible=False),
502
  }
503
  info_resp = session.get(
504
+ f"{DATASET_VIEWE_API_URL}/info?dataset={dataset}", timeout=20
505
  ).json()
506
  if "error" in info_resp:
507
  return {
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
- # --extra-index-url https://pypi.nvidia.com
2
- # cuml-cu11
3
  spaces
4
  gradio
5
  torch
 
1
+ --extra-index-url https://pypi.nvidia.com
2
+ cuml-cu11
3
  spaces
4
  gradio
5
  torch
prompts.py → templates.py RENAMED
@@ -29,3 +29,14 @@ Based on the information about the topic above, please create a short label of t
29
  """
30
 
31
  REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
 
 
 
 
 
 
 
 
 
 
 
 
29
  """
30
 
31
  REPRESENTATION_PROMPT = SYSTEM_PROMPT + EXAMPLE_PROMPT + MAIN_PROMPT
32
+
33
+ SPACE_REPO_CARD_CONTENT = """
34
+ ---
35
+ title: {dataset_id} topic modeling
36
+ sdk: static
37
+ pinned: false
38
+ datasets:
39
+ - {dataset_id}
40
+ ---
41
+
42
+ """