asoria HF staff commited on
Commit
bf92466
·
1 Parent(s): a5eff40

Replace model with inference client + llama3

Browse files
Files changed (3) hide show
  1. app.py +34 -49
  2. requirements.txt +0 -1
  3. src/templates.py +10 -0
app.py CHANGED
@@ -11,21 +11,13 @@ from dotenv import load_dotenv
11
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
12
  from bertopic import BERTopic
13
  from bertopic.representation import KeyBERTInspired
14
- from bertopic.representation import TextGeneration
15
 
16
- from huggingface_hub import HfApi
17
  from sklearn.feature_extraction.text import CountVectorizer
18
  from sentence_transformers import SentenceTransformer
19
- from torch import cuda, bfloat16
20
- from transformers import (
21
- BitsAndBytesConfig,
22
- AutoTokenizer,
23
- AutoModelForCausalLM,
24
- pipeline,
25
- )
26
 
27
  from src.hub import create_space_with_content
28
- from src.templates import REPRESENTATION_PROMPT, SPACE_REPO_CARD_CONTENT
29
  from src.viewer_api import (
30
  get_split_rows,
31
  get_parquet_urls,
@@ -60,35 +52,13 @@ logging.basicConfig(
60
 
61
  api = HfApi(token=HF_TOKEN)
62
 
63
- bnb_config = BitsAndBytesConfig(
64
- load_in_4bit=True,
65
- bnb_4bit_quant_type="nf4",
66
- bnb_4bit_use_double_quant=True,
67
- bnb_4bit_compute_dtype=bfloat16,
68
- )
69
-
70
- model_id = "meta-llama/Llama-2-7b-chat-hf"
71
- tokenizer = AutoTokenizer.from_pretrained(model_id)
72
- model = AutoModelForCausalLM.from_pretrained(
73
- model_id,
74
- trust_remote_code=True,
75
- quantization_config=bnb_config,
76
- device_map="auto",
77
- )
78
- model.eval()
79
- generator = pipeline(
80
- model=model,
81
- tokenizer=tokenizer,
82
- task="text-generation",
83
- temperature=0.1,
84
- max_new_tokens=500,
85
- repetition_penalty=1.1,
86
- )
87
-
88
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
89
  vectorizer_model = CountVectorizer(stop_words="english")
90
  representation_model = KeyBERTInspired()
91
 
 
 
92
 
93
  def calculate_embeddings(docs):
94
  return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
@@ -294,13 +264,6 @@ def generate_topics(dataset, config, split, column, plot_type):
294
  "",
295
  )
296
 
297
- dataset_clear_name = dataset.replace("/", "-")
298
- plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
299
- if plot_type == "DataMapPlot":
300
- topic_plot.savefig(plot_png, format="png", dpi=300)
301
- else:
302
- topic_plot.write_image(plot_png)
303
-
304
  all_topics = base_model.topics_
305
  topics_info = base_model.get_topic_info()
306
 
@@ -309,13 +272,27 @@ def generate_topics(dataset, config, split, column, plot_type):
309
  logging.info(
310
  f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
311
  )
312
- prompt = f"{REPRESENTATION_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
313
- logging.info(prompt)
314
- topic_description = generator(prompt)
315
- logging.info(topic_description)
316
- new_topics_by_text_generation[row["Topic"]] = topic_description[0][
317
- "generated_text"
318
- ].replace(prompt, "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  base_model.set_topic_labels(new_topics_by_text_generation)
320
 
321
  topics_info = base_model.get_topic_info()
@@ -350,6 +327,14 @@ def generate_topics(dataset, config, split, column, plot_type):
350
  title="",
351
  )
352
  )
 
 
 
 
 
 
 
 
353
  custom_labels = base_model.custom_labels_
354
  topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
355
  yield (
 
11
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
12
  from bertopic import BERTopic
13
  from bertopic.representation import KeyBERTInspired
 
14
 
15
+ from huggingface_hub import HfApi, InferenceClient
16
  from sklearn.feature_extraction.text import CountVectorizer
17
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
18
 
19
  from src.hub import create_space_with_content
20
+ from src.templates import LLAMA_3_8B_PROMPT, SPACE_REPO_CARD_CONTENT
21
  from src.viewer_api import (
22
  get_split_rows,
23
  get_parquet_urls,
 
52
 
53
  api = HfApi(token=HF_TOKEN)
54
 
55
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
57
  vectorizer_model = CountVectorizer(stop_words="english")
58
  representation_model = KeyBERTInspired()
59
 
60
+ inference_client = InferenceClient(model_id)
61
+
62
 
63
  def calculate_embeddings(docs):
64
  return embedding_model.encode(docs, show_progress_bar=True, batch_size=32)
 
264
  "",
265
  )
266
 
 
 
 
 
 
 
 
267
  all_topics = base_model.topics_
268
  topics_info = base_model.get_topic_info()
269
 
 
272
  logging.info(
273
  f"Processing topic: {row['Topic']} - Representation: {row['Representation']}"
274
  )
275
+ prompt = f"{LLAMA_3_8B_PROMPT.replace('[KEYWORDS]', ','.join(row['Representation']))}"
276
+ prompt_messages = [
277
+ {
278
+ "role": "system",
279
+ "content": "You are a helpful, respectful and honest assistant for labeling topics.",
280
+ },
281
+ {"role": "user", "content": prompt},
282
+ ]
283
+ output = inference_client.chat_completion(
284
+ messages=prompt_messages,
285
+ stream=False,
286
+ max_tokens=500,
287
+ top_p=0.8,
288
+ seed=42,
289
+ )
290
+ inference_response = output.choices[0].message.content
291
+ logging.info("Inference response:")
292
+ logging.info(inference_response)
293
+ new_topics_by_text_generation[row["Topic"]] = inference_response.replace(
294
+ "Topic=", ""
295
+ ).strip()
296
  base_model.set_topic_labels(new_topics_by_text_generation)
297
 
298
  topics_info = base_model.get_topic_info()
 
327
  title="",
328
  )
329
  )
330
+
331
+ dataset_clear_name = dataset.replace("/", "-")
332
+ plot_png = f"{dataset_clear_name}-{plot_type.lower()}.png"
333
+ if plot_type == "DataMapPlot":
334
+ topic_plot.savefig(plot_png, format="png", dpi=300)
335
+ else:
336
+ topic_plot.write_image(plot_png)
337
+
338
  custom_labels = base_model.custom_labels_
339
  topic_names_array = [custom_labels[doc_topic + 1] for doc_topic in all_topics]
340
  yield (
requirements.txt CHANGED
@@ -15,4 +15,3 @@ pandas
15
  numpy
16
  python-dotenv
17
  kaleido
18
- transformers
 
15
  numpy
16
  python-dotenv
17
  kaleido
 
src/templates.py CHANGED
@@ -22,6 +22,16 @@ Based on the information about the topic above, please create a short label of t
22
 
23
  REPRESENTATION_PROMPT = f"{SYSTEM_PROMPT}{EXAMPLE_PROMPT}{MAIN_PROMPT}"
24
 
 
 
 
 
 
 
 
 
 
 
25
  SPACE_REPO_CARD_CONTENT = """
26
  ---
27
  title: {dataset_id}
 
22
 
23
  REPRESENTATION_PROMPT = f"{SYSTEM_PROMPT}{EXAMPLE_PROMPT}{MAIN_PROMPT}"
24
 
25
+ LLAMA_3_8B_PROMPT = """
26
+ Example:
27
+ I have a topic that is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.
28
+ Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
29
+ Topic=Environmental impacts of eating meat
30
+ Instruction:
31
+ I have a topic that is described by the following keywords: '[KEYWORDS]'.
32
+ Based on the information about the topic above, please create a short label of this topic. Make sure you to only return the label and nothing more.
33
+ """
34
+
35
  SPACE_REPO_CARD_CONTENT = """
36
  ---
37
  title: {dataset_id}