jcrissa committed on
Commit b80a390 · 1 Parent(s): 3c2f453

edit app.py

Files changed (1)
  1. app.py +14 -12
app.py CHANGED
@@ -7,11 +7,10 @@ from transformers import AutoTokenizer
 # Load your fine-tuned Phi-3 model from Hugging Face
 MODEL_NAME = "jcrissa/phi3-new-t2i"

-# Check if CUDA is available, otherwise fall back to CPU
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+
+device = "cuda"

-# Function to load the Phi-3 model and tokenizer
-@spaces.GPU  # Reintroduced spaces.GPU decorator for GPU setup
 def load_phi3_model():
     try:
         # Load the Phi-3 model and tokenizer from Hugging Face
@@ -22,6 +21,9 @@ def load_phi3_model():
         )
         model.to(device)

+        # Prepare the model for inference
+        model = FastLanguageModel.for_inference(model)  # This is the necessary line
+
         # Configure tokenizer settings
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.padding_side = "left"
@@ -37,6 +39,7 @@ phi3_model, phi3_tokenizer = load_phi3_model()
 if phi3_model is None or phi3_tokenizer is None:
     raise RuntimeError("Model and tokenizer could not be loaded. Please check the Hugging Face model path or network connection.")

+@spaces.GPU(duration=120)
 # Function to generate text using Phi-3
 def generate(plain_text):
     try:
@@ -44,12 +47,11 @@ def generate(plain_text):
         input_ids = phi3_tokenizer(plain_text.strip(), return_tensors="pt").input_ids.to(device)
         eos_id = phi3_tokenizer.eos_token_id

-        # Generate the output from the model
+        # Generate the output from the model using sampling instead of beam search
         outputs = phi3_model.generate(
             input_ids,
-            do_sample=True,
-            max_new_tokens=75,
-            num_beams=8,
+            do_sample=True,  # Use sampling instead of beam search
+            max_new_tokens=75,
             num_return_sequences=1,
             eos_token_id=eos_id,
             pad_token_id=eos_id,
@@ -62,18 +64,18 @@ def generate(plain_text):
     except Exception as e:
         return f"Error during text generation: {e}"

+
 # Setup Gradio Interface
 txt = grad.Textbox(lines=1, label="Input Text", placeholder="Enter your prompt")
 out = grad.Textbox(lines=1, label="Generated Text")

-# Launch Gradio Interface with ZeroGPU-compatible setup
-gr.Interface(
+grad.Interface(
     fn=generate,
     inputs=txt,
     outputs=out,
     title="Fine-Tuned Phi-3 Model",
     description="This demo uses a fine-tuned Phi-3 model to optimize text prompts.",
-    allow_flagging="never",
+    flagging_mode="never",  # Replace `allow_flagging` with `flagging_mode`
     cache_examples=False,
     theme="default"
-).launch(share=True)  # Use `queue=True` instead of `enable_queue`
+).launch(share=True)  # Use `queue=True` instead of `enable_queue`