thechaiexperiment committed
Commit 2d991dc · 1 Parent(s): 7f09c10

Update app.py

Files changed (1)
  1. app.py +21 -74
app.py CHANGED
@@ -20,6 +20,7 @@ import pandas as pd
 import subprocess
 from typing import Dict, Optional
 import codecs
+from huggingface_hub import hf_hub_download
 
 try:
     subprocess.run(['git', 'lfs', 'pull'], check=True)
@@ -93,86 +94,32 @@ def load_models():
         return False
 
 
-class LFSEmbeddingsUnpickler(pickle.Unpickler):
-    def persistent_load(self, pid):
-        # Ensure persistent ID is ASCII string
-        if isinstance(pid, bytes):
-            return pid.decode('ascii')
-        return str(pid)
-
-def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[str, np.ndarray]]:
-    """
-    Load embeddings from a pickle file with support for Git LFS and protocol 0 requirements.
-
-    Args:
-        embeddings_path (str): Path to the pickle file containing embeddings
-
-    Returns:
-        Optional[Dict[str, np.ndarray]]: Dictionary of embeddings or None if loading fails
-    """
-    if not os.path.exists(embeddings_path):
-        print(f"Error: {embeddings_path} not found")
-        return None
-
+def load_embeddings(repo_id: str) -> Optional[Dict[str, np.ndarray]]:
+    """Load embeddings using HuggingFace Hub"""
     try:
-        # Open file in binary mode with buffering
-        with open(embeddings_path, 'rb', buffering=1024*1024) as f:
-            # Check if it's a Git LFS pointer file
-            first_line = f.peek(100)[:100].decode('utf-8', errors='ignore')
-            if 'version https://git-lfs.github.com/spec/' in first_line:
-                print("Warning: This appears to be a Git LFS pointer file.")
-                print("Please ensure you've properly downloaded the actual embeddings file using Git LFS")
-                return None
-
-            # Use custom unpickler with ASCII string handling
-            unpickler = LFSEmbeddingsUnpickler(f)
-
-            # Set encoding for protocol 0 compatibility
-            if hasattr(unpickler, 'encoding'):
-                unpickler.encoding = 'ascii'
-
-            try:
-                embeddings = unpickler.load()
-            except UnicodeDecodeError:
-                # If ASCII decode fails, try UTF-8
-                f.seek(0)
-                unpickler = pickle.Unpickler(f)
-                embeddings = unpickler.load()
+        # Download file from HF Hub
+        file_path = hf_hub_download(
+            repo_id=repo_id,
+            filename="embeddings.pkl",
+            repo_type="space"
+        )
+
+        # Load with custom unpickler
+        with open(file_path, 'rb') as f:
+            unpickler = pickle.Unpickler(f)
+            unpickler.encoding = 'ascii'
+            embeddings = unpickler.load()
 
-            # Validate the loaded data
-            if not isinstance(embeddings, dict):
-                print(f"Error: Expected dict, got {type(embeddings)}")
-                return None
+        if not isinstance(embeddings, dict):
+            return None
+
+        # Convert to numpy arrays
+        return {k: np.array(v, dtype=np.float32) for k, v in embeddings.items()}
 
-            # Convert values to numpy arrays
-            processed_embeddings = {}
-            for key, value in embeddings.items():
-                try:
-                    # Handle various input types
-                    if isinstance(value, np.ndarray):
-                        processed_embeddings[key] = value
-                    else:
-                        processed_embeddings[key] = np.array(value, dtype=np.float32)
-                except Exception as e:
-                    print(f"Warning: Could not process embedding for {key}: {e}")
-                    continue
-
-            if processed_embeddings:
-                sample_key = next(iter(processed_embeddings))
-                print(f"Data type: {type(processed_embeddings)}")
-                print(f"Total embeddings loaded: {len(processed_embeddings)}")
-                print(f"Sample embedding shape: {processed_embeddings[sample_key].shape}")
-                return processed_embeddings
-            else:
-                print("Error: No valid embeddings were processed")
-                return None
-
     except Exception as e:
-        print(f"Error loading embeddings: {str(e)}")
-        print("If using Git LFS, ensure you've run 'git lfs pull' to download the actual file")
+        print(f"Error loading embeddings: {e}")
         return None
 
-
 def load_documents_data():
     """Load document data with error handling"""
     try:
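
For context, a minimal sketch of how the reworked loader might be wired up elsewhere in app.py. The commit does not show the call site, so the Space id, the variable names, and the fallback printout below are placeholders, not part of the change; only load_embeddings (defined in the diff above) is taken from the commit.

# Sketch only: calling the new HF Hub based loader at startup.
from typing import Dict, Optional
import numpy as np

REPO_ID = "thechaiexperiment/<space-name>"  # placeholder; actual Space id not shown in this commit

embeddings: Optional[Dict[str, np.ndarray]] = load_embeddings(REPO_ID)
if embeddings is None:
    # Illustrative fallback; the real app may handle this differently.
    print("Embeddings could not be loaded; retrieval features disabled")
else:
    sample_key = next(iter(embeddings))
    # Each value is a float32 numpy array keyed by document id.
    print(f"Loaded {len(embeddings)} embeddings, sample shape {embeddings[sample_key].shape}")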