Grandediw commited on
Commit
40b886d
·
verified ·
1 Parent(s): 5e8b7be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -152
app.py CHANGED
@@ -1,187 +1,73 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
- from diffusers import DiffusionPipeline
5
- import torch
6
-
7
- from diffusers import DiffusionPipeline
8
  import torch
9
- from huggingface_hub import login
10
-
11
- import os
12
- from huggingface_hub import login
13
 
14
- # Access the token from environment variables
15
  token = os.getenv("HUGGINGFACE_API_TOKEN")
16
  if not token:
17
- raise ValueError("API token is missing! Please set it in the Hugging Face Space Secrets.")
18
 
19
- # Log in to Hugging Face
20
- login(token)
21
-
22
-
23
- # Model details
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
26
-
27
- # Load model
28
- pipe = DiffusionPipeline.from_pretrained(
29
- "Grandediw/lora_model",
30
- torch_dtype=torch_dtype,
31
- use_auth_token=True # Enables private model access
32
- )
33
- pipe = pipe.to(device)
34
-
35
-
36
- # Constants
37
- MAX_SEED = np.iinfo(np.int32).max
38
- MAX_IMAGE_SIZE = 1024
39
 
40
- # Inference function
41
- def infer(
42
- prompt,
43
- negative_prompt,
44
- seed,
45
- randomize_seed,
46
- width,
47
- height,
48
- guidance_scale,
49
- num_inference_steps,
50
- ):
51
- if randomize_seed:
52
- seed = random.randint(0, MAX_SEED)
53
 
54
- generator = torch.Generator(device).manual_seed(seed)
 
 
55
 
56
- # Generate the image
57
- image = pipe(
58
- prompt=prompt,
59
- negative_prompt=negative_prompt,
60
- guidance_scale=guidance_scale,
61
- num_inference_steps=num_inference_steps,
62
- width=width,
63
- height=height,
64
- generator=generator,
65
- ).images[0]
66
 
67
- return image, seed
 
68
 
69
- # Example prompts
70
- examples = [
71
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
72
- "An astronaut riding a green horse",
73
- "A delicious ceviche cheesecake slice",
74
- ]
75
 
76
- # Improved CSS for better styling
77
  css = """
78
  #interface-container {
79
  margin: 0 auto;
80
  max-width: 700px;
81
- padding: 10px;
82
- box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
83
  border-radius: 10px;
84
  background-color: #f9f9f9;
 
85
  }
86
  #header {
87
  text-align: center;
88
  font-size: 1.5em;
 
89
  margin-bottom: 20px;
90
  color: #333;
91
  }
92
- #advanced-settings {
93
- background-color: #f1f1f1;
94
- padding: 10px;
95
- border-radius: 8px;
96
- }
97
  """
98
 
99
- # Gradio interface
100
  with gr.Blocks(css=css) as demo:
101
  with gr.Box(elem_id="interface-container"):
102
- gr.Markdown(
103
- """
104
- <div id="header">🖼️ Text-to-Image Generator</div>
105
- Generate high-quality images from your text prompts with the fine-tuned LoRA model.
106
- """
107
- )
108
-
109
- # Main input row
110
- with gr.Row():
111
- prompt = gr.Textbox(
112
- label="Prompt",
113
- placeholder="Describe the image you want to create...",
114
- lines=2,
115
- )
116
- run_button = gr.Button("Generate Image", variant="primary")
117
-
118
- # Output image display
119
- result = gr.Image(label="Generated Image").style(height="512px")
120
-
121
- # Advanced settings
122
- with gr.Accordion("Advanced Settings", open=False, elem_id="advanced-settings"):
123
- negative_prompt = gr.Textbox(
124
- label="Negative Prompt",
125
- placeholder="What to exclude from the image...",
126
- )
127
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
128
- seed = gr.Number(label="Seed", value=0, interactive=True)
129
-
130
- with gr.Row():
131
- width = gr.Slider(
132
- label="Image Width",
133
- minimum=256,
134
- maximum=MAX_IMAGE_SIZE,
135
- step=64,
136
- value=512,
137
- )
138
- height = gr.Slider(
139
- label="Image Height",
140
- minimum=256,
141
- maximum=MAX_IMAGE_SIZE,
142
- step=64,
143
- value=512,
144
- )
145
-
146
- with gr.Row():
147
- guidance_scale = gr.Slider(
148
- label="Guidance Scale",
149
- minimum=0.0,
150
- maximum=20.0,
151
- step=0.1,
152
- value=7.5,
153
- )
154
- num_inference_steps = gr.Slider(
155
- label="Steps",
156
- minimum=10,
157
- maximum=100,
158
- step=5,
159
- value=50,
160
- )
161
 
162
- # Examples
163
- gr.Examples(
164
- examples=examples,
165
- inputs=[prompt],
166
- outputs=[result],
167
- label="Try these prompts",
168
- )
169
 
170
- # Event handler
171
- run_button.click(
172
- fn=infer,
173
- inputs=[
174
- prompt,
175
- negative_prompt,
176
- seed,
177
- randomize_seed,
178
- width,
179
- height,
180
- guidance_scale,
181
- num_inference_steps,
182
- ],
183
- outputs=[result, seed],
184
- )
185
 
 
186
  if __name__ == "__main__":
187
  demo.launch()
 
1
+ import os
2
  import gradio as gr
 
 
 
 
 
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from safetensors.torch import load_file
 
 
6
 
7
+ # Load the Hugging Face API token
8
  token = os.getenv("HUGGINGFACE_API_TOKEN")
9
  if not token:
10
+ raise ValueError("HUGGINGFACE_API_TOKEN is not set. Please add it in the Secrets section of your Space.")
11
 
12
+ # Configure device and data type
 
 
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Load the tokenizer and model
16
+ model_repo = "Grandediw/lora_model"
17
+ tokenizer = AutoTokenizer.from_pretrained(model_repo, use_auth_token=True)
18
+ base_model = AutoModel.from_pretrained(model_repo, use_auth_token=True)
 
 
 
 
 
 
 
 
 
19
 
20
+ # Load LoRA adapter weights
21
+ lora_weights_path = "adapter_model.safetensors" # Ensure this file is present in the same directory
22
+ lora_weights = load_file(lora_weights_path)
23
 
24
+ # Apply LoRA weights to the base model
25
+ for name, param in base_model.named_parameters():
26
+ if name in lora_weights:
27
+ param.data += lora_weights[name].to(device, dtype=param.dtype)
 
 
 
 
 
 
28
 
29
+ # Move the model to the device
30
+ base_model = base_model.to(device)
31
 
32
+ # Inference function
33
+ def infer(prompt, negative_prompt=None):
34
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
35
+ outputs = base_model(**inputs)
36
+ return outputs.last_hidden_state.mean(dim=1).cpu().detach().numpy() # Placeholder return
 
37
 
38
+ # Gradio Interface
39
  css = """
40
  #interface-container {
41
  margin: 0 auto;
42
  max-width: 700px;
43
+ padding: 15px;
 
44
  border-radius: 10px;
45
  background-color: #f9f9f9;
46
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
47
  }
48
  #header {
49
  text-align: center;
50
  font-size: 1.5em;
51
+ font-weight: bold;
52
  margin-bottom: 20px;
53
  color: #333;
54
  }
 
 
 
 
 
55
  """
56
 
 
57
  with gr.Blocks(css=css) as demo:
58
  with gr.Box(elem_id="interface-container"):
59
+ gr.Markdown("<div id='header'>LoRA Model Inference</div>")
60
+
61
+ # Input for prompt and run button
62
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
63
+ run_button = gr.Button("Generate Output", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ # Display output
66
+ output = gr.Textbox(label="Output")
 
 
 
 
 
67
 
68
+ # Connect button with inference
69
+ run_button.click(fn=infer, inputs=[prompt], outputs=[output])
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ # Launch the app
72
  if __name__ == "__main__":
73
  demo.launch()