dwb2023 committed
Commit be27f18
1 Parent(s): 3414b66

Create model_utils.py

Files changed (1):
  1. model_utils.py +61 -0
model_utils.py ADDED
@@ -0,0 +1,61 @@
import subprocess
import os
import torch
from transformers import (
    BitsAndBytesConfig,
    AutoConfig,
    AutoModelForCausalLM,
    LlavaNextForConditionalGeneration,
    LlavaForConditionalGeneration,
    PaliGemmaForConditionalGeneration,
    Idefics2ForConditionalGeneration,
)
import spaces

# Use the faster hf_transfer backend for model downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


def install_flash_attn():
    # Install flash-attn at runtime; skipping the CUDA build pulls the
    # prebuilt wheel instead of compiling kernels on the Space.
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
    )


# Map architecture names (as reported by AutoConfig) to their model classes.
ARCHITECTURE_MAP = {
    "LlavaNextForConditionalGeneration": LlavaNextForConditionalGeneration,
    "LlavaForConditionalGeneration": LlavaForConditionalGeneration,
    "PaliGemmaForConditionalGeneration": PaliGemmaForConditionalGeneration,
    "Idefics2ForConditionalGeneration": Idefics2ForConditionalGeneration,
    "AutoModelForCausalLM": AutoModelForCausalLM,
}


@spaces.GPU
def get_model_summary(model_name):
    try:
        config = AutoConfig.from_pretrained(model_name)
        architecture = config.architectures[0]
        quantization_config = getattr(config, "quantization_config", None)

        if quantization_config:
            # Rebuild a BitsAndBytesConfig from the values stored in the
            # model's config, falling back to sensible defaults.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=quantization_config.get("load_in_4bit", False),
                load_in_8bit=quantization_config.get("load_in_8bit", False),
                bnb_4bit_compute_dtype=quantization_config.get("bnb_4bit_compute_dtype", torch.float16),
                bnb_4bit_quant_type=quantization_config.get("bnb_4bit_quant_type", "nf4"),
                bnb_4bit_use_double_quant=quantization_config.get("bnb_4bit_use_double_quant", False),
                llm_int8_enable_fp32_cpu_offload=quantization_config.get("llm_int8_enable_fp32_cpu_offload", False),
                llm_int8_has_fp16_weight=quantization_config.get("llm_int8_has_fp16_weight", False),
                llm_int8_skip_modules=quantization_config.get("llm_int8_skip_modules", None),
                llm_int8_threshold=quantization_config.get("llm_int8_threshold", 6.0),
            )
        else:
            bnb_config = None

        # Fall back to AutoModelForCausalLM for architectures not in the map.
        model_class = ARCHITECTURE_MAP.get(architecture, AutoModelForCausalLM)
        model = model_class.from_pretrained(
            model_name, quantization_config=bnb_config, trust_remote_code=True
        )

        # Quantized models are placed on device by bitsandbytes; only move
        # full-precision models explicitly.
        if model and not quantization_config:
            model = model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        model_summary = str(model) if model else "Model architecture not found."
        return model_summary, ""
    except ValueError as ve:
        return "", f"ValueError: {ve}"
    except EnvironmentError as ee:
        return "", f"EnvironmentError: {ee}"
    except Exception as e:
        return "", str(e)
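Usage, as a minimal sketch: get_model_summary returns a (summary, error) pair, so callers branch on the second element. Running it assumes a Hugging Face Space environment (the spaces import and @spaces.GPU decorator require one), and the model id below is only an illustration, not something this commit prescribes.

    from model_utils import get_model_summary

    summary, error = get_model_summary("llava-hf/llava-1.5-7b-hf")
    if error:
        print("Could not load model:", error)
    else:
        print(summary)  # printable module tree of the loaded model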