PitterTMYT commited on
Commit
168745b
·
verified ·
1 Parent(s): 13a4d5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -13
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch, os
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
  from huggingface_hub import login
5
 
6
  def init_model():
@@ -12,23 +12,15 @@ def init_model():
12
 
13
  login(hf_token, add_to_git_credential=True)
14
 
15
- tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", token=hf_token)
16
-
17
- # Настройка квантования
18
- quant_config = BitsAndBytesConfig(
19
- load_in_8bit=True, # Использование 8-битного квантования
20
- llm_int8_threshold=6.0, # Порог активации для 8-битных весов
21
- )
22
-
23
  model = AutoModelForCausalLM.from_pretrained(
24
  "IlyaGusev/saiga_gemma2_10b",
25
- token=hf_token,
26
  torch_dtype=torch.float16, # Использование float16 для уменьшения потребления памяти
27
- device_map="auto", # Автоматическое распределение модели на GPU
28
- quantization_config=quant_config, # Применение конфигурации квантования
29
  )
30
 
31
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
  model.to(device)
33
 
34
  def generate_response(prompt, max_length=100, temperature=0.7, top_p=0.85, repetition_penalty=1.1):
 
1
  import gradio as gr
2
  import torch, os
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from huggingface_hub import login
5
 
6
  def init_model():
 
12
 
13
  login(hf_token, add_to_git_credential=True)
14
 
15
+ tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", use_auth_token=hf_token)
 
 
 
 
 
 
 
16
  model = AutoModelForCausalLM.from_pretrained(
17
  "IlyaGusev/saiga_gemma2_10b",
18
+ use_auth_token=hf_token,
19
  torch_dtype=torch.float16, # Использование float16 для уменьшения потребления памяти
20
+ low_cpu_mem_usage=True # Настройка для уменьшения использования памяти на CPU
 
21
  )
22
 
23
+ device = torch.device("cpu") # Использование CPU
24
  model.to(device)
25
 
26
  def generate_response(prompt, max_length=100, temperature=0.7, top_p=0.85, repetition_penalty=1.1):