sbicy committed on
Commit
1e8fe07
·
verified ·
1 Parent(s): a830b73

trying to update to use Zero GPU

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import gradio as gr
@@ -18,27 +19,27 @@ wandb.login(key=wandb_api_key)
18
 
19
  # Define function to load model and pipeline dynamically
20
  def load_pipeline(model_name, fine_tuned=False):
21
- # Define model paths for pre-trained and fine-tuned versions
22
  paths = {
23
- "gpt2": ("gpt2-medium", "sbicy/finetuned-gpt2"),
24
- "gpt_neo": ("EleutherAI/gpt-neo-1.3B", "sbicy/finetuned-gpt-neo"),
25
- "gpt_j": ("EleutherAI/gpt-j-6B", "sbicy/finetuned-gpt-j")
26
  }
27
 
28
  pretrained_model_name, finetuned_model_path = paths[model_name]
29
  model_path = finetuned_model_path if fine_tuned else pretrained_model_name
30
 
31
  # Load model and tokenizer
32
- model = AutoModelForCausalLM.from_pretrained(model_path, use_auth_token=hf_api_key)
33
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=hf_api_key)
34
  tokenizer.pad_token = tokenizer.eos_token
35
 
36
- # Set up pipeline with specified device
37
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
38
 
39
- # Define Gradio app function
 
40
  def compare_single_model(prompt, model_choice, temperature, top_p, max_length):
41
- # Load pre-trained and fine-tuned pipelines
42
  pretrained_pipeline = load_pipeline(model_choice, fine_tuned=False)
43
  finetuned_pipeline = load_pipeline(model_choice, fine_tuned=True)
44
 
@@ -46,9 +47,13 @@ def compare_single_model(prompt, model_choice, temperature, top_p, max_length):
46
  pretrained_response = pretrained_pipeline(prompt, temperature=temperature, top_p=top_p, max_length=int(max_length))[0]["generated_text"]
47
  finetuned_response = finetuned_pipeline(prompt, temperature=temperature, top_p=top_p, max_length=int(max_length))[0]["generated_text"]
48
 
 
 
 
 
49
  return pretrained_response, finetuned_response
50
 
51
- # Gradio interface setup
52
  interface = gr.Interface(
53
  fn=compare_single_model,
54
  inputs=[
 
1
+ import spaces
2
  import os
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  import gradio as gr
 
19
 
20
  # Define function to load model and pipeline dynamically
21
  def load_pipeline(model_name, fine_tuned=False):
22
+ # Set model paths for pre-trained and fine-tuned versions
23
  paths = {
24
+ "gpt2": ("gpt2-medium", "path/to/finetuned_gpt2"),
25
+ "gpt_neo": ("EleutherAI/gpt-neo-1.3B", "path/to/finetuned_gpt_neo"),
26
+ "gpt_j": ("EleutherAI/gpt-j-6B", "path/to/finetuned_gpt_j")
27
  }
28
 
29
  pretrained_model_name, finetuned_model_path = paths[model_name]
30
  model_path = finetuned_model_path if fine_tuned else pretrained_model_name
31
 
32
  # Load model and tokenizer
33
+ model = AutoModelForCausalLM.from_pretrained(model_path)
34
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
35
  tokenizer.pad_token = tokenizer.eos_token
36
 
37
+ # Set up pipeline with GPU
38
+ return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
39
 
40
+ # Add the GPU decorator to the generate function
41
@spaces.GPU(duration=120)  # Zero GPU: reserve the device for up to 120 s per call
def compare_single_model(prompt, model_choice, temperature, top_p, max_length):
    """Generate text from both the pre-trained and fine-tuned variants of a model.

    Args:
        prompt: Text prompt fed to both pipelines.
        model_choice: Model family key understood by load_pipeline.
        temperature: Sampling temperature forwarded to generation.
        top_p: Nucleus-sampling threshold forwarded to generation.
        max_length: Maximum generated length (coerced to int).

    Returns:
        Tuple of (pretrained_response, finetuned_response) generated texts.
    """
    # Local import: torch is not in the module's import block, and the
    # original code crashed with NameError on torch.cuda.empty_cache().
    import torch

    pretrained_pipeline = load_pipeline(model_choice, fine_tuned=False)
    finetuned_pipeline = load_pipeline(model_choice, fine_tuned=True)

    pretrained_response = pretrained_pipeline(prompt, temperature=temperature, top_p=top_p, max_length=int(max_length))[0]["generated_text"]
    finetuned_response = finetuned_pipeline(prompt, temperature=temperature, top_p=top_p, max_length=int(max_length))[0]["generated_text"]

    # Free GPU memory before the Zero GPU slice is released.
    del pretrained_pipeline, finetuned_pipeline
    torch.cuda.empty_cache()

    return pretrained_response, finetuned_response
55
 
56
+ # Gradio interface
57
  interface = gr.Interface(
58
  fn=compare_single_model,
59
  inputs=[