licesma commited on
Commit
30ff71b
·
1 Parent(s): 6be6272

Move paht to a variable

Browse files
Files changed (1) hide show
  1. demo_rag.py +3 -2
demo_rag.py CHANGED
@@ -4,10 +4,11 @@ import torch
4
  import time
5
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
7
 
8
  # Load tokenizer and model
9
- tokenizer = AutoTokenizer.from_pretrained("./deepseek-coder-1.3b-instruct")
10
- model = AutoModelForCausalLM.from_pretrained("./deepseek-coder-1.3b-instruct", torch_dtype=torch.bfloat16, device_map=device)
11
 
12
  # Initialize RAG and add schema docs
13
  retriever = SQLMetadataRetriever()
 
4
  import time
5
 
6
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ pretrain_path = "./deepseek-coder-1.3b-instruct"
8
 
9
  # Load tokenizer and model
10
+ tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
11
+ model = AutoModelForCausalLM.from_pretrained(pretrain_path, torch_dtype=torch.bfloat16, device_map=device)
12
 
13
  # Initialize RAG and add schema docs
14
  retriever = SQLMetadataRetriever()