Alina Lozovskaya committed
Commit 9469eae
1 Parent(s): 0e60add

Improve model size calculation

Files changed (1):
  backend/app/utils/model_validation.py  +42 -40
backend/app/utils/model_validation.py CHANGED

@@ -5,10 +5,12 @@ import re
 from typing import Tuple, Optional, Dict, Any
 import aiohttp
 from huggingface_hub import HfApi, ModelCard, hf_hub_download
+from huggingface_hub import hf_api
 from transformers import AutoConfig, AutoTokenizer
 from app.config.base import HF_TOKEN, API
 from app.utils.logging import LogFormatter
 
+
 logger = logging.getLogger(__name__)
 
 class ModelValidator:
@@ -54,78 +56,78 @@ class ModelValidator:
             logger.error(LogFormatter.error(error_msg, e))
             return False, str(e), None
 
-    async def get_safetensors_metadata(self, model_id: str, filename: str = "model.safetensors") -> Optional[Dict]:
+    async def get_safetensors_metadata(self, model_id: str, is_adapter: bool = False, revision: str = "main") -> Optional[Dict]:
         """Get metadata from a safetensors file"""
         try:
-            url = f"{API['HUB']}/{model_id}/raw/main/{filename}"
-            async with aiohttp.ClientSession() as session:
-                async with session.get(url, headers=self.headers) as response:
-                    if response.status == 200:
-                        # Read only the first 32KB to get the metadata
-                        header = await response.content.read(32768)
-                        # Parse metadata length from the first 8 bytes
-                        metadata_len = int.from_bytes(header[:8], byteorder='little')
-                        metadata_bytes = header[8:8+metadata_len]
-                        return json.loads(metadata_bytes)
-            return None
+            if is_adapter:
+                metadata = await asyncio.to_thread(
+                    hf_api.parse_safetensors_file_metadata,
+                    model_id,
+                    "adapter_model.safetensors",
+                    token=self.token,
+                    revision=revision,
+                )
+            else:
+                metadata = await asyncio.to_thread(
+                    hf_api.get_safetensors_metadata,
+                    repo_id=model_id,
+                    token=self.token,
+                    revision=revision,
+                )
+            return metadata
+
         except Exception as e:
-            logger.warning(f"Failed to get safetensors metadata: {str(e)}")
+            logger.error(f"Failed to get safetensors metadata: {str(e)}")
             return None
-
+
     async def get_model_size(
         self,
         model_info: Any,
         precision: str,
-        base_model: str
+        base_model: str,
+        revision: str
     ) -> Tuple[Optional[float], Optional[str]]:
         """Get model size in billions of parameters"""
         try:
             logger.info(LogFormatter.info(f"Checking model size for {model_info.modelId}"))
-
+
             # Check if model is adapter
             is_adapter = any(s.rfilename == "adapter_config.json" for s in model_info.siblings if hasattr(s, 'rfilename'))
-
+
             # Try to get size from safetensors first
             model_size = None
-
+
             if is_adapter and base_model:
                 # For adapters, we need both adapter and base model sizes
-                adapter_meta = await self.get_safetensors_metadata(model_info.id, "adapter_model.safetensors")
-                base_meta = await self.get_safetensors_metadata(base_model)
-
+                adapter_meta = await self.get_safetensors_metadata(model_info.id, is_adapter=True, revision=revision)
+                base_meta = await self.get_safetensors_metadata(base_model, revision="main")
+
                 if adapter_meta and base_meta:
-                    adapter_size = sum(int(v.split(',')[0]) for v in adapter_meta.get("tensor_metadata", {}).values())
-                    base_size = sum(int(v.split(',')[0]) for v in base_meta.get("tensor_metadata", {}).values())
+                    adapter_size = sum(adapter_meta.parameter_count.values())
+                    base_size = sum(base_meta.parameter_count.values())
                     model_size = (adapter_size + base_size) / (2 * 1e9)  # Convert to billions, assuming float16
             else:
                 # For regular models, just get the model size
-                meta = await self.get_safetensors_metadata(model_info.id)
+                meta = await self.get_safetensors_metadata(model_info.id, revision=revision)
                 if meta:
-                    total_params = sum(int(v.split(',')[0]) for v in meta.get("tensor_metadata", {}).values())
+                    total_params = sum(meta.parameter_count.values())
                     model_size = total_params / (2 * 1e9)  # Convert to billions, assuming float16
-
+
             if model_size is None:
-                # Fallback: Try to get size from model name
-                size_pattern = re.compile(r"(\d+\.?\d*)b")  # Matches patterns like "7b", "13b", "1.1b"
-                size_match = re.search(size_pattern, model_info.id.lower())
-
-                if size_match:
-                    size_str = size_match.group(1)
-                    model_size = float(size_str)
-                else:
-                    return None, "Could not determine model size from safetensors or model name"
-
+                # If model size could not be determined, return an error
+                return None, "Model size could not be determined"
+
             # Adjust size for GPTQ models
             size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.id.lower()) else 1
             model_size = round(size_factor * model_size, 3)
-
+
             logger.info(LogFormatter.success(f"Model size: {model_size}B parameters"))
             return model_size, None
-
+
         except Exception as e:
-            error_msg = "Failed to get model size"
-            logger.error(LogFormatter.error(error_msg, e))
+            logger.error(LogFormatter.error(f"Error while determining model size: {e}"))
             return None, str(e)
+
 
     async def check_chat_template(
         self,
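
For context on what was removed: a safetensors file begins with an 8-byte little-endian unsigned integer giving the length of a JSON header that describes every tensor, which is exactly what the old code read over HTTP and decoded by hand. A minimal local sketch of that parsing, assuming a file on disk (the helper name read_safetensors_header is illustrative, not from this repo):

import json
import struct

def read_safetensors_header(path: str) -> dict:
    """Parse the JSON header of a local .safetensors file.

    Layout: bytes 0-7 hold the header length as a little-endian
    uint64; the next header_len bytes are UTF-8 JSON describing
    each tensor (dtype, shape, data_offsets).
    """
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        return json.loads(f.read(header_len))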
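The new code delegates that work to huggingface_hub instead. A rough standalone sketch of the new path, assuming a recent huggingface_hub release that exports these helpers (the repo id is only an example):

from huggingface_hub import get_safetensors_metadata

# Fetches and parses safetensors headers for a whole repo,
# following the index file for sharded checkpoints.
meta = get_safetensors_metadata("mistralai/Mistral-7B-v0.1")

# parameter_count maps dtype name to number of parameters,
# e.g. {"BF16": 7241732096}; summing the values gives the total.
total_params = sum(meta.parameter_count.values())
print(f"{total_params / 1e9:.2f}B parameters")

For a single file such as an adapter checkpoint, parse_safetensors_file_metadata(repo_id, "adapter_model.safetensors") exposes the same kind of parameter_count, which is what the adapter branch of the diff relies on.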
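Both branches wrap these blocking client calls in asyncio.to_thread so the hub round-trip does not stall the event loop; other coroutines keep running while a worker thread waits on I/O. The pattern in isolation (Python 3.9+; blocking_fetch is a stand-in, not a real API):

import asyncio
import time

def blocking_fetch(repo_id: str) -> str:
    # Stand-in for a synchronous call such as
    # hf_api.get_safetensors_metadata(repo_id).
    time.sleep(1)
    return f"metadata for {repo_id}"

async def main() -> None:
    # asyncio.to_thread runs the blocking function in a worker
    # thread and yields control until it finishes.
    meta = await asyncio.to_thread(blocking_fetch, "org/model")
    print(meta)

asyncio.run(main())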