abrah926 committed
Commit a124d51 · verified · 1 Parent(s): 707dafd

batch embedding

Files changed (1)
  1. embeddings.py +35 -79
embeddings.py CHANGED
@@ -4,43 +4,10 @@ import faiss
  import torch
  import numpy as np
  import os
- import json
 
  def log(message):
      print(f"✅ {message}")
 
-
-
- # ✅ Ensure data folder exists
- DATA_DIR = "data"
- os.makedirs(DATA_DIR, exist_ok=True)
-
- # ✅ List of datasets
- datasets_list = {
-     "sales": "goendalf666/sales-conversations",
-     "blended": "blended_skill_talk",
-     "dialog": "daily_dialog",
-     "multiwoz": "multi_woz_v22",
- }
-
- def save_dataset_to_file(dataset_name, dataset):
-     """Save dataset to a local JSON file."""
-     file_path = os.path.join(DATA_DIR, f"{dataset_name}.json")
-
-     with open(file_path, "w") as f:
-         json.dump(dataset["train"].to_dict(), f)
-
-     print(f"✅ Saved {dataset_name} to {file_path}")
-
- # ✅ Load & Save all datasets
- for name, dataset_id in datasets_list.items():
-     dataset = load_dataset(dataset_id, split="train")
-     save_dataset_to_file(name, dataset)
-
- print("✅ All datasets saved locally!")
-
-
-
  # ✅ Load datasets
  datasets = {
      "sales": load_dataset("goendalf666/sales-conversations"),
@@ -49,80 +16,69 @@ datasets = {
      "multiwoz": load_dataset("multi_woz_v22"),
  }
 
- # ✅ Load MiniLM model and tokenizer
- model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Model for embeddings
+ # ✅ Load MiniLM model for embeddings
+ model_name = "sentence-transformers/all-MiniLM-L6-v2"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModel.from_pretrained(model_name)
 
  def embed_text(texts):
+     """Generate embeddings for a batch of texts."""
      inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
      with torch.no_grad():
          embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
      return embeddings
 
-
- # ✅ Extract and embed the datasets
- def create_embeddings(dataset_name, dataset):
+ # ✅ Batch processing function
+ def create_embeddings(dataset_name, dataset, batch_size=100):
      print(f"📥 Creating embeddings for {dataset_name}...")
 
      if dataset_name == "goendalf666/sales-conversations":
          texts = [" ".join(row.values()) for row in dataset["train"]]
-
-     elif dataset_name == "AlekseyKorshuk/persona-chat":
-         texts = [" ".join(utterance["candidates"]) for utterance in dataset["train"]["utterances"]]
-
      elif dataset_name == "blended_skill_talk":
          texts = [" ".join(row["free_messages"] + row["guided_messages"]) for row in dataset["train"]]
-
      elif dataset_name == "daily_dialog":
          texts = [" ".join(row["dialog"]) for row in dataset["train"]]
-
     elif dataset_name == "multi_woz_v22":
          texts = [" ".join(row["turns"]["utterance"]) for row in dataset["train"]]
-
      else:
          print(f"⚠️ Warning: Dataset {dataset_name} not handled properly!")
          texts = []
 
-     # ✅ Verify dataset extraction
-     if len(texts) == 0:
-         print(f"❌ ERROR: No text extracted from {dataset_name}! Check dataset structure.")
-     else:
-         print(f"✅ Extracted {len(texts)} texts from {dataset_name}. Sample:\n{texts[:3]}")
-
-     return texts
+     log(f"✅ Extracted {len(texts)} texts from {dataset_name}.")
 
- # ✅ Embed and store in FAISS
- for name, dataset in datasets.items():
-     texts = create_embeddings(name, dataset)
+     # Process in batches
+     all_embeddings = []
+     for i in range(0, len(texts), batch_size):
+         batch = texts[i : i + batch_size]
+         batch_embeddings = embed_text(batch)
+         all_embeddings.append(batch_embeddings)
+
+         # ✅ Log progress
+         log(f"🚀 Processed {i + len(batch)}/{len(texts)} embeddings for {dataset_name}...")
 
-     if len(texts) > 0:  # ✅ Only embed if texts exist
-         embeddings = embed_text(texts)
-         print(f"✅ Generated embeddings shape: {embeddings.shape}")
+     # Convert list of numpy arrays to a single numpy array
+     all_embeddings = np.vstack(all_embeddings)
+     return all_embeddings
 
-         index = save_embeddings_to_faiss(embeddings)
-         print(f"✅ Embeddings for {name} saved to FAISS.")
+ # ✅ Save embeddings to FAISS with unique filename
+ def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
+     index_file = f"{index_name}.faiss"
+
+     # ✅ Check if previous FAISS index exists, append if needed
+     if os.path.exists(index_file):
+         log("🔄 Loading existing FAISS index to append...")
+         index = faiss.read_index(index_file)
+         index.add(np.array(embeddings).astype(np.float32))
      else:
-         print(f"⚠️ Skipping embedding for {name} (No valid texts).")
+         index = faiss.IndexFlatL2(embeddings.shape[1])
+         index.add(np.array(embeddings).astype(np.float32))
 
+     faiss.write_index(index, index_file)  # ✅ Save FAISS index
+     log(f"✅ Saved FAISS index: {index_file}")
 
-
- # ✅ Save embeddings to a database
- def save_embeddings_to_faiss(embeddings, index_name="my_embeddings"):
-     print("Saving embeddings to FAISS...")
-     index = faiss.IndexFlatL2(embeddings.shape[1])  # Assuming 512-dimensional embeddings
-     index.add(np.array(embeddings).astype(np.float32))
-     faiss.write_index(index, index_name)  # Save FAISS index to file
-     return index
-
- # ✅ Create embeddings for all datasets
+ # ✅ Run embeddings process
  for name, dataset in datasets.items():
-     embeddings = create_embeddings(name, dataset)
-     index = save_embeddings_to_faiss(embeddings)
-     print(f"Embeddings for {name} saved to FAISS.")
-
-
- # ✅ Check FAISS index after saving
- index = faiss.read_index("my_embeddings")  # Load the index
- print(f"📊 FAISS index contains {index.ntotal} vectors.")  # Check how many embeddings were stored
+     embeddings = create_embeddings(name, dataset, batch_size=100)
+     save_embeddings_to_faiss(embeddings, index_name=name)
+     log(f"✅ Embeddings for {name} saved to FAISS.")
 
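
For reference, a minimal sketch of how the per-dataset indexes written by this commit could be sanity-checked and queried afterwards. It assumes a completed run, so a file such as sales.faiss exists next to the script; the probe sentence and the choice of 5 neighbors are arbitrary, and embed_text repeats the same mean-pooled MiniLM encoding used in embeddings.py (all-MiniLM-L6-v2 produces 384-dimensional vectors, so index.d should be 384, not the 512 the removed comment assumed):

import faiss
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel

model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def embed_text(texts):
    """Mean-pooled MiniLM embeddings, as in embeddings.py."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        return model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()

# Load one index produced by save_embeddings_to_faiss(embeddings, index_name=name).
# "sales.faiss" is an assumption: it exists only after a successful run.
index = faiss.read_index("sales.faiss")
print(f"Index holds {index.ntotal} vectors of dimension {index.d}")

# Embed a probe sentence and fetch its 5 nearest stored vectors (L2 distance).
query = embed_text(["Do you offer a discount for annual plans?"]).astype(np.float32)
distances, ids = index.search(query, 5)
print("Nearest ids:", ids[0], "distances:", distances[0])

Note that IndexFlatL2 stores only the vectors: mapping the returned ids back to their source texts requires persisting the texts (or an id-to-text mapping) alongside the index, which embeddings.py does not yet do.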