jcrissa committed on
Commit bde3b1d · 1 Parent(s): 33e84ee
Files changed (1)
  1. app.py +24 -35
app.py CHANGED
@@ -3,64 +3,53 @@ import torch
 from unsloth import FastLanguageModel
 from transformers import AutoTokenizer
 
-# Load your fine-tuned Phi-3 model
-def load_phi3():
-    model_name = "jcrissa/phi3-new-t2i"  # Your trained model
-    max_seq_length = 4096  # Ensure correct max length
-
-    # Load fine-tuned model
+# Load your fine-tuned Phi-3 model from Hugging Face
+MODEL_NAME = "jcrissa/phi3-new-t2i"
+
+def load_phi3_model():
     model, tokenizer = FastLanguageModel.from_pretrained(
-        model_name,
-        max_seq_length=max_seq_length,
-        dtype=None,  # Uses default torch dtype (float16 or bfloat16 if available)
-        load_in_4bit=True  # Uses 4-bit quantization for efficiency
+        MODEL_NAME,
+        max_seq_length=4096,  # Ensure it matches your fine-tuning
+        dtype=None  # Use `torch.float16` if running on GPU
     )
 
-    tokenizer.pad_token = tokenizer.eos_token  # Ensure padding is set
+    tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "left"
-
+
     return model, tokenizer
 
-# Initialize model and tokenizer
-phi3_model, phi3_tokenizer = load_phi3()
+phi3_model, phi3_tokenizer = load_phi3_model()
 
-# Function to generate prompts
+# Function to generate text using Phi-3
 def generate(plain_text):
-    input_ids = phi3_tokenizer(plain_text.strip(), return_tensors="pt").input_ids.cuda()  # Move to GPU if available
+    input_ids = phi3_tokenizer(plain_text.strip(), return_tensors="pt").input_ids
     eos_id = phi3_tokenizer.eos_token_id
 
     outputs = phi3_model.generate(
         input_ids,
-        do_sample=True,
+        do_sample=True,
         max_new_tokens=75,
-        num_beams=5,
+        num_beams=8,
         num_return_sequences=1,
         eos_token_id=eos_id,
-        pad_token_id=eos_id,
-        length_penalty=1.0
+        pad_token_id=eos_id,
+        length_penalty=-1.0
     )
 
-    output_texts = phi3_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    return output_texts[0]
+    output_text = phi3_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return output_text.strip()
 
-# Gradio UI
-txt = grad.Textbox(lines=1, label="Initial Text", placeholder="Input Prompt")
-out = grad.Textbox(lines=1, label="Optimized Prompt")
-examples = [
-    "A rabbit is wearing a space suit",
-    "Several railroad tracks with one train passing by",
-    "The roof is wet from the rain",
-    "Cats dancing in a space club"
-]
+# Setup Gradio Interface
+txt = grad.Textbox(lines=1, label="Input Text", placeholder="Enter your prompt")
+out = grad.Textbox(lines=1, label="Generated Text")
 
 grad.Interface(
     fn=generate,
     inputs=txt,
     outputs=out,
-    title="Phi-3 Prompt Generator",
-    description="Fine-tuned Phi-3 model (`jcrissa/phi3-new-t2i`) for text-to-image prompt generation.",
-    examples=examples,
-    allow_flagging='never',
+    title="Fine-Tuned Phi-3 Model",
+    description="This demo uses a fine-tuned Phi-3 model to optimize text prompts.",
+    allow_flagging="never",
     cache_examples=False,
     theme="default"
 ).launch(enable_queue=True, debug=True)
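A note on the dropped `load_in_4bit=True`: the previous version quantized the model to 4-bit at load time, which the new `load_phi3_model()` no longer does. If GPU memory is tight, a sketch along the old version's lines can restore it; whether the parameter is accepted depends on the installed unsloth version, so treat this as an assumption rather than part of the commit:

```python
# Sketch only, not part of this commit: re-enable 4-bit loading as in
# the previous app.py. load_in_4bit availability depends on your
# unsloth version.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "jcrissa/phi3-new-t2i",
    max_seq_length=4096,   # keep in sync with fine-tuning
    dtype=None,            # let unsloth pick float16/bfloat16
    load_in_4bit=True,     # 4-bit quantization, removed by this commit
)
```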
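For readers trying the updated `generate()` path locally, here is a minimal, hypothetical smoke test (not part of the commit). It reuses the `phi3_model` / `phi3_tokenizer` names from app.py; the explicit move of `input_ids` to `phi3_model.device` is our assumption, since the new version drops the `.cuda()` call while unsloth typically places the model on GPU:

```python
# Hypothetical smoke test for the updated generate() path; not part of
# this commit. Reuses phi3_model / phi3_tokenizer defined in app.py.
import torch

def run_example(prompt: str) -> str:
    inputs = phi3_tokenizer(prompt.strip(), return_tensors="pt")
    # Assumption: move inputs to the model's device, since the new code
    # no longer calls .cuda() and the model is usually loaded on GPU.
    input_ids = inputs.input_ids.to(phi3_model.device)
    with torch.no_grad():
        outputs = phi3_model.generate(
            input_ids,
            do_sample=True,       # sampling combined with beam search
            max_new_tokens=75,
            num_beams=8,
            num_return_sequences=1,
            eos_token_id=phi3_tokenizer.eos_token_id,
            pad_token_id=phi3_tokenizer.eos_token_id,
            length_penalty=-1.0,  # negative value favors shorter outputs
        )
    return phi3_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

print(run_example("A rabbit is wearing a space suit"))
```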