willsh1997 committed on
Commit
12f7ab8
·
1 Parent(s): 18ce8a9

:clown_face: remove load in 4 bit, change dtype to bfloat16

Browse files
Files changed (1) hide show
  1. llm_translate_gradio.py +4 -4
llm_translate_gradio.py CHANGED
@@ -20,7 +20,7 @@ def load_models():
20
  nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
21
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
22
  "facebook/nllb-200-distilled-600M",
23
- load_in_4bit=True,
24
  device_map="auto"
25
  )
26
 
@@ -32,9 +32,9 @@ def load_models():
32
 
33
  llama_model = AutoModelForCausalLM.from_pretrained(
34
  model_id,
35
- load_in_4bit=True,
36
  device_map="auto",
37
- torch_dtype=torch.float16
38
  )
39
 
40
  print("Models loaded successfully!")
@@ -82,7 +82,7 @@ def translate_to_lang(input_str, target_lang):
82
  forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
83
  max_new_tokens=512,
84
  do_sample=False,
85
- num_beams=2
86
  )
87
 
88
  output_str = nllb_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
 
20
  nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
21
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
22
  "facebook/nllb-200-distilled-600M",
23
+ # load_in_4bit=True,
24
  device_map="auto"
25
  )
26
 
 
32
 
33
  llama_model = AutoModelForCausalLM.from_pretrained(
34
  model_id,
35
+ # load_in_4bit=True,
36
  device_map="auto",
37
+ torch_dtype=torch.bfloat16
38
  )
39
 
40
  print("Models loaded successfully!")
 
82
  forced_bos_token_id=nllb_tokenizer.convert_tokens_to_ids(target_lang),
83
  max_new_tokens=512,
84
  do_sample=False,
85
+ num_beams=1
86
  )
87
 
88
  output_str = nllb_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]