DigiP-AI commited on
Commit
5489ddf
·
verified ·
1 Parent(s): f7ccfea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -205
app.py CHANGED
@@ -1,209 +1,104 @@
1
- import re
 
 
 
 
2
  import os
3
- import yaml
4
- import tempfile
5
- import subprocess
6
- from pathlib import Path
7
-
8
  import torch
9
- import gradio as gr
10
-
11
- #from src.flux.xflux_pipeline import XFluxPipeline
12
-
13
-
14
- def list_dirs(path):
15
- if path is None or path == "None" or path == "":
16
- return
17
-
18
- if not os.path.exists(path):
19
- path = os.path.dirname(path)
20
- if not os.path.exists(path):
21
- return
22
-
23
- if not os.path.isdir(path):
24
- path = os.path.dirname(path)
25
-
26
- def natural_sort_key(s, regex=re.compile("([0-9]+)")):
27
- return [
28
- int(text) if text.isdigit() else text.lower() for text in regex.split(s)
29
- ]
30
-
31
- subdirs = [
32
- (item, os.path.join(path, item))
33
- for item in os.listdir(path)
34
- if os.path.isdir(os.path.join(path, item))
35
- ]
36
- subdirs = [
37
- filename
38
- for item, filename in subdirs
39
- if item[0] != "." and item not in ["__pycache__"]
40
- ]
41
- subdirs = sorted(subdirs, key=natural_sort_key)
42
- if os.path.dirname(path) != "":
43
- dirs = [os.path.dirname(path), path] + subdirs
44
- else:
45
- dirs = [path] + subdirs
46
-
47
- if os.sep == "\\":
48
- dirs = [d.replace("\\", "/") for d in dirs]
49
- for d in dirs:
50
- yield d
51
-
52
- def list_train_data_dirs():
53
- current_train_data_dir = "."
54
- return list(list_dirs(current_train_data_dir))
55
-
56
- def update_config(d, u):
57
- for k, v in u.items():
58
- if isinstance(v, dict):
59
- d[k] = update_config(d.get(k, {}), v)
60
- else:
61
- # convert Gradio components to strings
62
- if hasattr(v, 'value'):
63
- d[k] = str(v.value)
64
- else:
65
- try:
66
- d[k] = int(v)
67
- except (TypeError, ValueError):
68
- d[k] = str(v)
69
- return d
70
-
71
- def start_lora_training(
72
- data_dir: str, output_dir: str, lr: float, steps: int, rank: int
73
- ):
74
- inputs = {
75
- "data_config": {
76
- "img_dir": data_dir,
77
- },
78
- "output_dir": output_dir,
79
- "learning_rate": lr,
80
- "rank": rank,
81
- "max_train_steps": steps,
82
- }
83
-
84
- if not os.path.exists(output_dir):
85
- os.makedirs(output_dir)
86
- print(f"Creating folder {output_dir} for the output checkpoint file...")
87
-
88
- script_path = Path(__file__).resolve()
89
- config_path = script_path.parent / "train_configs" / "test_lora.yaml"
90
- with open(config_path, 'r') as file:
91
- config = yaml.safe_load(file)
92
-
93
- config = update_config(config, inputs)
94
- print("Config file is updated...", config)
95
- with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".yaml") as temp_file:
96
- yaml.dump(config, temp_file, default_flow_style=False)
97
- tmp_config_path = temp_file.name
98
-
99
- command = ["accelerate", "launch", "train_flux_lora_deepspeed.py", "--config", tmp_config_path]
100
- result = subprocess.run(command, check=True)
101
-
102
- # rRemove the temporary file after the command is run
103
- Path(tmp_config_path).unlink()
104
-
105
- return result
106
-
107
-
108
- def create_demo(
109
- model_type: str,
110
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
111
- offload: bool = False,
112
- ckpt_dir: str = "",
113
- ):
114
- xflux_pipeline = XFluxPipeline(model_type, device, offload)
115
- checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
116
-
117
- with gr.Blocks() as demo:
118
- gr.Markdown(f"# Flux Adapters by XLabs AI - Model: {model_type}")
119
- with gr.Tab("Inference"):
120
- with gr.Row():
121
- with gr.Column():
122
- prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
123
-
124
- with gr.Accordion("Generation Options", open=False):
125
- with gr.Row():
126
- width = gr.Slider(512, 2048, 1024, step=16, label="Width")
127
- height = gr.Slider(512, 2048, 1024, step=16, label="Height")
128
- neg_prompt = gr.Textbox(label="Negative Prompt", value="bad photo")
129
- with gr.Row():
130
- num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
131
- timestep_to_start_cfg = gr.Slider(1, 50, 1, step=1, label="timestep_to_start_cfg")
132
- with gr.Row():
133
- guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
134
- true_gs = gr.Slider(1.0, 5.0, 3.5, step=0.1, label="True Guidance", interactive=True)
135
- seed = gr.Textbox(-1, label="Seed (-1 for random)")
136
-
137
- with gr.Accordion("ControlNet Options", open=False):
138
- control_type = gr.Dropdown(["canny", "hed", "depth"], label="Control type")
139
- control_weight = gr.Slider(0.0, 1.0, 0.8, step=0.1, label="Controlnet weight", interactive=True)
140
- local_path = gr.Dropdown(checkpoints, label="Controlnet Checkpoint",
141
- info="Local Path to Controlnet weights (if no, it will be downloaded from HF)"
142
- )
143
- controlnet_image = gr.Image(label="Input Controlnet Image", visible=True, interactive=True)
144
-
145
- with gr.Accordion("LoRA Options", open=False):
146
- lora_weight = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="LoRA weight", interactive=True)
147
- lora_local_path = gr.Dropdown(
148
- checkpoints, label="LoRA Checkpoint", info="Local Path to Lora weights"
149
- )
150
-
151
- with gr.Accordion("IP Adapter Options", open=False):
152
- image_prompt = gr.Image(label="image_prompt", visible=True, interactive=True)
153
- ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="ip_scale")
154
- neg_image_prompt = gr.Image(label="neg_image_prompt", visible=True, interactive=True)
155
- neg_ip_scale = gr.Slider(0.0, 1.0, 1.0, step=0.1, label="neg_ip_scale")
156
- ip_local_path = gr.Dropdown(
157
- checkpoints, label="IP Adapter Checkpoint",
158
- info="Local Path to IP Adapter weights (if no, it will be downloaded from HF)"
159
- )
160
- generate_btn = gr.Button("Generate")
161
-
162
- with gr.Column():
163
- output_image = gr.Image(label="Generated Image")
164
- download_btn = gr.File(label="Download full-resolution")
165
-
166
- inputs = [prompt, image_prompt, controlnet_image, width, height, guidance,
167
- num_steps, seed, true_gs, ip_scale, neg_ip_scale, neg_prompt,
168
- neg_image_prompt, timestep_to_start_cfg, control_type, control_weight,
169
- lora_weight, local_path, lora_local_path, ip_local_path
170
- ]
171
- generate_btn.click(
172
- fn=xflux_pipeline.gradio_generate,
173
- inputs=inputs,
174
- outputs=[output_image, download_btn],
175
- )
176
-
177
- with gr.Tab("LoRA Finetuning"):
178
- data_dir = gr.Dropdown(list_train_data_dirs(),
179
- label="Training images (directory containing the training images)"
180
- )
181
- output_dir = gr.Textbox(label="Output Path", value="lora_checkpoint")
182
-
183
- with gr.Accordion("Training Options", open=True):
184
- lr = gr.Textbox(label="Learning Rate", value="1e-5")
185
- steps = gr.Slider(10000, 20000, 20000, step=100, label="Train Steps")
186
- rank = gr.Slider(1, 100, 16, step=1, label="LoRa Rank")
187
-
188
- training_btn = gr.Button("Start training")
189
- training_btn.click(
190
- fn=start_lora_training,
191
- inputs=[data_dir, output_dir, lr, steps, rank],
192
- outputs=[],
193
- )
194
-
195
-
196
- return demo
197
-
198
  if __name__ == "__main__":
199
- import argparse
200
- parser = argparse.ArgumentParser(description="Flux")
201
- parser.add_argument("--name", type=str, default="flux-dev", help="Model name")
202
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
203
- parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
204
- parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
205
- parser.add_argument("--ckpt_dir", type=str, default=".", help="Folder with checkpoints in safetensors format")
206
- args = parser.parse_args()
207
 
208
- demo = create_demo(args.name, args.device, args.offload, args.ckpt_dir)
209
- demo.launch(share=args.share)
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import base64
5
+ import requests
6
  import os
7
+ import random
 
 
 
 
8
  import torch
9
+ import subprocess
10
+ import numpy as np
11
+ import cv2
12
+ from transformers import AutoProcessor, AutoModelForCausalLM
13
+ from diffusers import DiffusionPipeline
14
+ from datetime import datetime
15
+ from mistralai import Mistral
16
+ from theme import theme
17
+ from fastapi import FastAPI
18
+
19
+ app = FastAPI()
20
+
21
+
22
+
23
+ api_key = os.getenv("MISTRAL_API_KEY")
24
+ Mistralclient = Mistral(api_key=api_key)
25
+
26
+ def flip_image(x):
27
+ return np.fliplr(x)
28
+
29
+ def encode_image(image_path):
30
+ """Encode the image to base64."""
31
+ try:
32
+ # Open the image file
33
+ image = Image.open(image_path).convert("RGB")
34
+
35
+ # Resize the image to a height of 512 while maintaining the aspect ratio
36
+ base_height = 512
37
+ h_percent = (base_height / float(image.size[1]))
38
+ w_size = int((float(image.size[0]) * float(h_percent)))
39
+ image = image.resize((w_size, base_height), Image.LANCZOS)
40
+
41
+ # Convert the image to a byte stream
42
+ buffered = BytesIO()
43
+ image.save(buffered, format="JPEG")
44
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
45
+
46
+ return img_str
47
+ except FileNotFoundError:
48
+ print(f"Error: The file {image_path} was not found.")
49
+ return None
50
+ except Exception as e: # Add generic exception handling
51
+ print(f"Error: {e}")
52
+ return None
53
+
54
+ def feifeichat(image):
55
+ try:
56
+ model = "pixtral-large-2411"
57
+ # Define the messages for the chat
58
+ base64_image = encode_image(image)
59
+ messages = [{
60
+ "role":
61
+ "user",
62
+ "content": [
63
+ {
64
+ "type": "text",
65
+ "text": "Please provide a detailed description of this photo"
66
+ },
67
+ {
68
+ "type": "image_url",
69
+ "image_url": f"data:image/jpeg;base64,{base64_image}"
70
+ },
71
+ ],
72
+ "stream": False,
73
+ }]
74
+
75
+ partial_message = ""
76
+ for chunk in Mistralclient.chat.stream(model=model, messages=messages):
77
+ if chunk.data.choices[0].delta.content is not None:
78
+ partial_message = partial_message + chunk.data.choices[
79
+ 0].delta.content
80
+ yield partial_message
81
+ except Exception as e: # Add common exception handling
82
+ print(f"Error: {e}")
83
+ return "Please upload a photo"
84
+
85
+
86
+ with gr.Blocks(theme=theme, elem_id="app-container") as app:
87
+ gr.Markdown("Image To Flux Prompt")
88
+ with gr.Tab(label="Image To Prompt"):
89
+ with gr.Row():
90
+ with gr.Column():
91
+ input_img = gr.Image(label="Input Picture",height=320,type="filepath")
92
+ submit_btn = gr.Button(value="Submit", variant='primary')
93
+ with gr.Column():
94
+ output_text = gr.Textbox(label="Flux Prompt", show_copy_button = True)
95
+ clr_button =gr.Button("Clear",variant="primary", elem_id="clear_button")
96
+ clr_button.click(lambda: gr.Textbox(value=""), None, output_text)
97
+
98
+ submit_btn.click(feifeichat, [input_img], [output_text])
99
+
100
+
101
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  if __name__ == "__main__":
103
+ app.launch()
 
 
 
 
 
 
 
104