nmarafo commited on
Commit
5155d78
verified
1 Parent(s): 3f907da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -22
app.py CHANGED
@@ -7,27 +7,8 @@ import os
7
  # Cargar el token de Hugging Face desde los secretos
8
  token = os.environ.get("HF_TOKEN")
9
 
10
- # IDs del modelo y el tokenizador
11
- model_id = "PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed"
12
- tokenizer_id = "google/shieldgemma-2b"
13
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, token=token)
14
-
15
- # Configurar BitsAndBytes para cuantizaci贸n en 4 bits
16
- quantization_config = BitsAndBytesConfig(
17
- load_in_4bit=True,
18
- bnb_4bit_use_double_quant=True,
19
- bnb_4bit_quant_type="nf4",
20
- bnb_4bit_compute_dtype=torch.bfloat16
21
- )
22
-
23
- # Cargar el modelo con la configuraci贸n de cuantizaci贸n
24
- model = AutoModelForCausalLM.from_pretrained(
25
- model_id,
26
- quantization_config=quantization_config,
27
- trust_remote_code=True,
28
- device_map="auto",
29
- token=token
30
- )
31
 
32
  # Funci贸n para generar el prompt dependiendo del idioma seleccionado
33
  def generar_prompt(message, tipo_clasificacion, idioma):
@@ -95,7 +76,7 @@ def generar_prompt(message, tipo_clasificacion, idioma):
95
  def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, language, harm_type):
96
  prompt = generar_prompt(message, harm_type, language)
97
 
98
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
99
 
100
  with torch.no_grad():
101
  logits = model(**inputs).logits
 
7
  # Cargar el token de Hugging Face desde los secretos
8
  token = os.environ.get("HF_TOKEN")
9
 
10
+ model = AutoModelForCausalLM.from_pretrained("PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed", trust_remote_code=True, device_map='auto')
11
+ tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Funci贸n para generar el prompt dependiendo del idioma seleccionado
14
  def generar_prompt(message, tipo_clasificacion, idioma):
 
76
  def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, language, harm_type):
77
  prompt = generar_prompt(message, harm_type, language)
78
 
79
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
80
 
81
  with torch.no_grad():
82
  logits = model(**inputs).logits