trungtienluong commited on
Commit
d42a22f
1 Parent(s): 45a71f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -8,23 +8,29 @@ from sklearn.model_selection import train_test_split
8
 
9
  MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
10
 
11
- model = AutoModelForCausalLM.from_pretrained(
12
- MODEL_NAME,
13
- device_map="auto",
14
- trust_remote_code=True
15
- )
 
 
 
 
 
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
  tokenizer.pad_token = tokenizer.eos_token
19
  model.gradient_checkpointing_enable()
20
 
21
- # Load the pre-trained model with PEFT
22
- peft_config = PeftConfig.from_pretrained("trungtienluong/experiments500czephymodelngay11t6l1")
23
- model = PeftModel.from_pretrained(model, "trungtienluong/experiments500czephymodelngay11t6l1")
24
-
25
- # Move the model to the appropriate device
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
- model.to(device)
 
28
 
29
  # Load the dataset
30
  dataset = load_dataset("trungtienluong/500cau")
 
8
 
9
  MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
10
 
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ try:
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ MODEL_NAME,
16
+ device_map="auto",
17
+ trust_remote_code=True
18
+ ).to(device)
19
+ except Exception as e:
20
+ print(f"Error loading base model: {e}")
21
 
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
23
  tokenizer.pad_token = tokenizer.eos_token
24
  model.gradient_checkpointing_enable()
25
 
26
+ try:
27
+ # Load the pre-trained model with PEFT
28
+ peft_config = PeftConfig.from_pretrained("trungtienluong/experiments500czephymodelngay11t6l1")
29
+ model = PeftModel.from_pretrained(model, "trungtienluong/experiments500czephymodelngay11t6l1").to(device)
30
+ except KeyError as e:
31
+ print(f"KeyError during PEFT model loading: {e}")
32
+ except Exception as e:
33
+ print(f"Error loading PEFT model: {e}")
34
 
35
  # Load the dataset
36
  dataset = load_dataset("trungtienluong/500cau")