Maximofn commited on
Commit
b756054
·
1 Parent(s): 56ffcf9

Improve model loading with device-specific configuration and error handling

Browse files

- Add try-except block for robust model loading
- Implement separate loading strategies for CUDA and CPU devices
- Include low CPU memory usage option for CUDA
- Add informative print statements for device and loading status
- Enhance error handling during model initialization

Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -16,15 +16,30 @@ print("Cargando modelo y tokenizer...")
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
18
 
19
- # Load the model in BF16 format for better performance and lower memory usage
20
- tokenizer = AutoTokenizer.from_pretrained(model_name)
21
- model = AutoModelForCausalLM.from_pretrained(
22
- model_name,
23
- torch_dtype=torch.bfloat16,
24
- device_map="auto" # This will automatically distribute the model across available GPUs
25
- )
26
-
27
- print(f"Modelo cargado en dispositivo: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Define the function that calls the model
30
  def call_model(state: MessagesState):
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
18
 
19
+ try:
20
+ # Load the model in BF16 format for better performance and lower memory usage
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+
23
+ if device == "cuda":
24
+ print("Usando GPU para el modelo...")
25
+ model = AutoModelForCausalLM.from_pretrained(
26
+ model_name,
27
+ torch_dtype=torch.bfloat16,
28
+ device_map="auto",
29
+ low_cpu_mem_usage=True
30
+ )
31
+ else:
32
+ print("Usando CPU para el modelo...")
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ model_name,
35
+ device_map={"": device},
36
+ torch_dtype=torch.float32
37
+ )
38
+
39
+ print(f"Modelo cargado exitosamente en: {device}")
40
+ except Exception as e:
41
+ print(f"Error al cargar el modelo: {str(e)}")
42
+ raise
43
 
44
  # Define the function that calls the model
45
  def call_model(state: MessagesState):