sbicy commited on
Commit
a919327
·
verified ·
1 Parent(s): 1e8fe07

fixed paths

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -19,22 +19,22 @@ wandb.login(key=wandb_api_key)
19
 
20
  # Define function to load model and pipeline dynamically
21
  def load_pipeline(model_name, fine_tuned=False):
22
- # Set model paths for pre-trained and fine-tuned versions
23
  paths = {
24
- "gpt2": ("gpt2-medium", "path/to/finetuned_gpt2"),
25
- "gpt_neo": ("EleutherAI/gpt-neo-1.3B", "path/to/finetuned_gpt_neo"),
26
- "gpt_j": ("EleutherAI/gpt-j-6B", "path/to/finetuned_gpt_j")
27
  }
28
 
29
- pretrained_model_name, finetuned_model_path = paths[model_name]
30
- model_path = finetuned_model_path if fine_tuned else pretrained_model_name
31
-
32
- # Load model and tokenizer
33
  model = AutoModelForCausalLM.from_pretrained(model_path)
34
  tokenizer = AutoTokenizer.from_pretrained(model_path)
35
  tokenizer.pad_token = tokenizer.eos_token
36
-
37
- # Set up pipeline with GPU
38
  return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
39
 
40
  # Add the GPU decorator to the generate function
 
19
 
20
  # Define function to load model and pipeline dynamically
21
  def load_pipeline(model_name, fine_tuned=False):
22
+ # Use the Hugging Face repo paths for each model
23
  paths = {
24
+ "gpt2": ("gpt2-medium", "sbicy/finetuned-gpt2"),
25
+ "gpt_neo": ("EleutherAI/gpt-neo-1.3B", "sbicy/finetuned-gpt-neo"),
26
+ "gpt_j": ("EleutherAI/gpt-j-6B", "sbicy/finetuned-gpt-j")
27
  }
28
 
29
+ pretrained_model_name, finetuned_model_repo = paths[model_name]
30
+ model_path = finetuned_model_repo if fine_tuned else pretrained_model_name
31
+
32
+ # Load model and tokenizer from Hugging Face Hub
33
  model = AutoModelForCausalLM.from_pretrained(model_path)
34
  tokenizer = AutoTokenizer.from_pretrained(model_path)
35
  tokenizer.pad_token = tokenizer.eos_token
36
+
37
+ # Set up pipeline with the specified device
38
  return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
39
 
40
  # Add the GPU decorator to the generate function