jatingocodeo committed on
Commit
0e007bb
·
verified ·
1 Parent(s): d70db54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -114
app.py CHANGED
@@ -1,137 +1,127 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from peft import PeftModel, PeftConfig
5
- from PIL import Image
6
- import torchvision.datasets as datasets
7
- import numpy as np
8
- import os
9
 
10
- def load_model():
11
- # Create offload directory
12
- os.makedirs("offload", exist_ok=True)
 
 
13
 
14
- # Configure device map for memory efficiency
15
- device_map = {
16
- 'base_model.model.model.embed_tokens': 0,
17
- 'base_model.model.model.layers.0': 0,
18
- 'base_model.model.model.layers.1': 0,
19
- 'base_model.model.model.layers.2': 0,
20
- 'base_model.model.model.layers.3': 0,
21
- 'base_model.model.model.layers.4': 'cpu',
22
- 'base_model.model.model.layers.5': 'cpu',
23
- 'base_model.model.model.layers.6': 'cpu',
24
- 'base_model.model.model.layers.7': 'cpu',
25
- 'base_model.model.model.layers.8': 'cpu',
26
- 'base_model.model.model.norm': 'cpu',
27
- 'base_model.model.lm_head': 0,
28
- }
29
 
30
  base_model = AutoModelForCausalLM.from_pretrained(
31
- "microsoft/Phi-3-mini-4k-instruct",
32
- trust_remote_code=True,
33
- device_map=device_map, # Use custom device map
34
- torch_dtype=torch.float32,
35
- attn_implementation='eager',
36
- offload_folder="offload"
37
  )
38
 
39
- model = PeftModel.from_pretrained(
40
- base_model,
41
- "jatingocodeo/phi-vlm",
42
- device_map=device_map,
43
- offload_folder="offload"
44
- )
45
-
46
- tokenizer = AutoTokenizer.from_pretrained("jatingocodeo/phi-vlm")
47
-
48
  return model, tokenizer
49
 
50
- def generate_description(image, model, tokenizer):
51
- # Convert image to RGB if needed
52
- if image.mode != "RGB":
53
- image = image.convert("RGB")
54
-
55
- # Resize image to match training size (32x32)
56
- image = image.resize((32, 32))
57
-
58
- # Convert image to tensor and normalize
59
- image_tensor = torch.FloatTensor(np.array(image)).permute(2, 0, 1) / 255.0
60
-
61
- # Prepare prompt with image tensor
62
- prompt = f"""Below is an image. Please describe it in detail.
63
-
64
- Image: {image_tensor}
65
- Description: """
66
 
67
  # Tokenize input
68
- inputs = tokenizer(
69
- prompt,
70
- return_tensors="pt",
71
- padding=True,
72
- truncation=True,
73
- max_length=128
74
- ).to(model.device)
75
 
76
- # Generate description
77
  with torch.no_grad():
78
  outputs = model.generate(
79
- input_ids=inputs.input_ids,
80
- attention_mask=inputs.attention_mask,
81
- max_new_tokens=100,
82
- temperature=0.7,
83
- do_sample=True,
84
- top_p=0.9
 
85
  )
86
 
87
- # Decode and return the generated text
88
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
89
- return generated_text.split("Description: ")[-1].strip()
90
-
91
- # Load model
92
- print("Loading model...")
93
- model, tokenizer = load_model()
94
-
95
- # Get CIFAR10 examples
96
- def get_cifar_examples():
97
- cifar10_test = datasets.CIFAR10(root='./data', train=False, download=True)
98
- classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
99
- 'dog', 'frog', 'horse', 'ship', 'truck']
100
 
101
- examples = []
102
- used_classes = set()
 
 
 
 
 
 
 
103
 
104
- for idx in range(len(cifar10_test)):
105
- img, label = cifar10_test[idx]
106
- if classes[label] not in used_classes:
107
- img_path = f"examples/{classes[label]}_example.jpg"
108
- img.save(img_path)
109
- examples.append(img_path)
110
- used_classes.add(classes[label])
111
-
112
- if len(used_classes) == 10:
113
- break
 
 
 
114
 
115
- return examples
116
-
117
- # Create Gradio interface
118
- def process_image(image):
119
- return generate_description(image, model, tokenizer)
120
-
121
- # Get examples
122
- examples = get_cifar_examples()
123
-
124
- # Define interface
125
- iface = gr.Interface(
126
- fn=process_image,
127
- inputs=gr.Image(type="pil"),
128
- outputs=gr.Textbox(label="Generated Description"),
129
- title="Image Description Generator",
130
- description="""Upload an image and get a detailed description generated by our fine-tuned VLM model.
131
- Below are sample images from CIFAR10 dataset that you can try.""",
132
- examples=[[ex] for ex in examples]
133
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- # Launch the interface
136
  if __name__ == "__main__":
137
- iface.launch()
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from peft import PeftModel
 
 
 
 
5
 
def load_model(model_id):
    """Load the tokenizer and the adapter-wrapped Phi-2 model.

    Args:
        model_id: Identifier passed to both ``AutoTokenizer.from_pretrained``
            and ``PeftModel.from_pretrained``; it must therefore provide the
            tokenizer files as well as the adapter weights.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model`` is the ``PeftModel`` that
        wraps the ``microsoft/phi-2`` base model.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Fall back to the EOS token so a padding token is always defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): half precision is requested unconditionally; float16 can be
    # slow or unsupported on CPU-only hosts -- confirm the deployment target.
    base_model_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    backbone = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        **base_model_options,
    )

    # The adapter is attached on top of the base weights; it is not merged
    # into them here.
    adapted_model = PeftModel.from_pretrained(backbone, model_id)
    return adapted_model, tokenizer
26
 
def generate_response(instruction, model, tokenizer, max_length=200, temperature=0.7, top_p=0.9):
    """Sample a completion for ``instruction`` and return only the new text.

    Args:
        instruction: Prompt text; surrounding whitespace is stripped.
        model: Object exposing ``device`` and ``generate(**kwargs)``.
            Assumed to be a decoder-only (causal) model whose output
            sequence starts with the prompt tokens.
        tokenizer: Callable returning a mapping with ``input_ids`` that
            supports ``.to(device)``, and exposing ``decode`` and
            ``eos_token_id``.
        max_length: Maximum number of *new* tokens to generate. Cast to
            ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The decoded completion with special tokens removed and surrounding
        whitespace stripped. The prompt itself is not included.
    """
    input_text = instruction.strip()

    # Tokenize input and move it to the model's device.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[-1]

    # Generate response without tracking gradients.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
        )

    # Slice by token count rather than by character count: decoding does not
    # always reproduce the prompt character-for-character, so cutting the
    # decoded string at len(input_text) can drop or leak characters.
    completion_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
54
+
def create_demo(model_id):
    """Build the Gradio interface that serves completions from the model.

    Args:
        model_id: Identifier forwarded to ``load_model``.

    Returns:
        The configured ``gr.Interface``; the caller is responsible for
        launching it.
    """
    # Load once; the callback below captures both objects in its closure.
    model, tokenizer = load_model(model_id)

    # Shared by the sliders and by the callback's default arguments so the
    # two cannot drift apart.
    default_max_length = 150
    default_temperature = 0.7
    default_top_p = 0.9

    def process_input(
        instruction,
        max_length=default_max_length,
        temperature=default_temperature,
        top_p=default_top_p,
    ):
        """Run generation for the UI and report failures as text.

        The generation parameters have defaults because the example rows
        below only provide the prompt; Gradio may call this function with
        just that value when it pre-computes example outputs.
        """
        try:
            return generate_response(
                instruction,
                model,
                tokenizer,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
            )
        except Exception as e:
            # UI boundary: show the failure in the output box instead of
            # raising into the Gradio worker.
            return f"Error generating response: {str(e)}"

    description = (
        "This is a generative model trained using GRPO (Group Relative Policy Optimization)\n"
        "on the TLDR dataset. The model was trained to generate completions of around 150 characters.\n"
        "\n"
        "You can adjust the generation parameters:\n"
        "- **Maximum Length**: Controls the maximum length of the generated response\n"
        "- **Temperature**: Higher values make the output more random, lower values make it more focused\n"
        "- **Top P**: Controls the cumulative probability threshold for token sampling\n"
    )

    demo = gr.Interface(
        fn=process_input,
        inputs=[
            gr.Textbox(
                label="Input Text",
                placeholder="Enter your text here...",
                lines=4,
            ),
            gr.Slider(
                minimum=50,
                maximum=500,
                value=default_max_length,
                step=10,
                label="Maximum Length",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=default_temperature,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=default_top_p,
                step=0.1,
                label="Top P",
            ),
        ],
        outputs=gr.Textbox(label="Completion", lines=8),
        title="Phi-2 GRPO Model Demo",
        description=description,
        examples=[
            ["The quick brown fox jumps over the lazy dog."],
            ["In this tutorial, we will explore how to build a neural network for image classification."],
            ["The best way to prepare for an interview is to"],
            ["Python is a popular programming language because"],
        ],
    )
    return demo
122
 
 
if __name__ == "__main__":
    # Identifier of the model to serve; it is handed to create_demo, which
    # loads the tokenizer and adapter from it.
    model_id = "jatingocodeo/phi2-grpo"
    demo = create_demo(model_id=model_id)
    demo.launch()