davanstrien (HF Staff) committed
Commit a302e07 · 1 Parent(s): f8148b8

add device detection for model inference and improve dataset collection logging

Files changed (1): main.py +21 -6
main.py CHANGED
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
 import polars as pl
 from huggingface_hub import HfApi
 from transformers import AutoTokenizer
+import torch
 
 # Configuration constants
 MODEL_NAME = "davanstrien/SmolLM2-360M-tldr-sft-2025-02-12_15-13"
@@ -19,6 +20,13 @@ EMBEDDING_MODEL = "nomic-ai/modernbert-embed-base"
 BATCH_SIZE = 1000
 CACHE_TTL = "60"
 
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif torch.backends.mps.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
+
 hf_api = HfApi()
 
 
@@ -72,8 +80,9 @@ app.add_middleware(
 
 # Define the embedding function at module level
 def get_embedding_function():
+    logger.info(f"Using device: {DEVICE}")
     return embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name="nomic-ai/modernbert-embed-base"
+        model_name="nomic-ai/modernbert-embed-base", device=DEVICE
     )
 
 
@@ -95,7 +104,7 @@ def setup_database():
         metadata={"hnsw:space": "cosine"},
     )
 
-    # TODO incremental updates
+    # Load dataset data
     df = pl.scan_parquet(
         "hf://datasets/davanstrien/datasets_with_metadata_and_summaries/data/train-*.parquet"
     )
@@ -103,14 +112,21 @@ def setup_database():
         pl.col("datasetId").str.contains_any(["open-llm-leaderboard-old/"]).not_()
     )
     row_count = df.select(pl.len()).collect().item()
-    logger.info(f"Row count of new data: {row_count}")
-    if dataset_collection.count() < row_count:
+    logger.info(f"Row count of dataset data: {row_count}")
+
+    # Check if we need to update the collection
+    current_count = dataset_collection.count()
+    logger.info(f"Current dataset collection count: {current_count}")
+
+    if current_count < row_count:
+        logger.info(
+            f"Updating dataset collection with {row_count - current_count} new records"
+        )
         # Load parquet files and upsert into ChromaDB
         df = df.select(
             ["datasetId", "summary", "likes", "downloads", "last_modified"]
        )
         df = df.collect()
-        BATCH_SIZE = 1000
         total_rows = len(df)
 
         for i in range(0, total_rows, BATCH_SIZE):
@@ -148,7 +164,6 @@ def setup_database():
             ["modelId", "summary", "likes", "downloads", "last_modified"]
         )
         model_df = model_df.collect()
-        BATCH_SIZE = 1000
         total_rows = len(model_df)
 
         for i in range(0, total_rows, BATCH_SIZE):
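
Note: both hunks cut off at the top of the batched upsert loop, which is also where the shadowing local BATCH_SIZE = 1000 was removed in favor of the module-level constant. A minimal sketch of that batching pattern, assuming a ChromaDB collection and a collected polars DataFrame with the columns selected above; the helper name upsert_in_batches and the metadata fields are illustrative, not taken from main.py:

import polars as pl

BATCH_SIZE = 1000  # module-level constant, as deduplicated in this commit


def upsert_in_batches(collection, df: pl.DataFrame) -> None:
    """Upsert rows into a ChromaDB collection in fixed-size batches (sketch)."""
    total_rows = len(df)
    for i in range(0, total_rows, BATCH_SIZE):
        batch = df.slice(i, BATCH_SIZE)  # polars slice(offset, length)
        collection.upsert(
            ids=batch["datasetId"].to_list(),      # dataset IDs double as ChromaDB ids
            documents=batch["summary"].to_list(),  # text the embedding function encodes
            metadatas=[                            # illustrative metadata fields
                {"likes": likes, "downloads": downloads}
                for likes, downloads in zip(
                    batch["likes"].to_list(), batch["downloads"].to_list()
                )
            ],
        )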