mobenta commited on
Commit
e9ad477
·
verified ·
1 Parent(s): 7f6036f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app (6).py +152 -0
  2. requirements (1).txt +11 -0
app (6).py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
6
+ from diffusers import DiffusionPipeline
7
+ import random
8
+ import numpy as np
9
+ import os
10
+ import subprocess
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
+
13
+ # Initialize models
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.bfloat16
16
+
17
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
18
+
19
+ # FLUX.1-dev model
20
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token = huggingface_token).to(device)
21
+
22
+ # Initialize Florence model
23
+ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
24
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
25
+
26
+ # Prompt Enhancer
27
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
28
+
29
+ MAX_SEED = np.iinfo(np.int32).max
30
+ MAX_IMAGE_SIZE = 2048
31
+
32
+ # Florence caption function
33
+ def florence_caption(image):
34
+ # Convert image to PIL if it's not already
35
+ if not isinstance(image, Image.Image):
36
+ image = Image.fromarray(image)
37
+
38
+ inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
39
+ generated_ids = florence_model.generate(
40
+ input_ids=inputs["input_ids"],
41
+ pixel_values=inputs["pixel_values"],
42
+ max_new_tokens=1024,
43
+ early_stopping=False,
44
+ do_sample=False,
45
+ num_beams=3,
46
+ )
47
+ generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
48
+ parsed_answer = florence_processor.post_process_generation(
49
+ generated_text,
50
+ task="<MORE_DETAILED_CAPTION>",
51
+ image_size=(image.width, image.height)
52
+ )
53
+ return parsed_answer["<MORE_DETAILED_CAPTION>"]
54
+
55
+ # Prompt Enhancer function
56
+ def enhance_prompt(input_prompt):
57
+ result = enhancer_long("Enhance the description: " + input_prompt)
58
+ enhanced_text = result[0]['summary_text']
59
+ return enhanced_text
60
+
61
+ @spaces.GPU(duration=190)
62
+ def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
63
+ if image is not None:
64
+ # Convert image to PIL if it's not already
65
+ if not isinstance(image, Image.Image):
66
+ image = Image.fromarray(image)
67
+
68
+ prompt = florence_caption(image)
69
+ else:
70
+ prompt = text_prompt
71
+
72
+ if use_enhancer:
73
+ prompt = enhance_prompt(prompt)
74
+
75
+ if randomize_seed:
76
+ seed = random.randint(0, MAX_SEED)
77
+
78
+ generator = torch.Generator(device=device).manual_seed(seed)
79
+
80
+ image = pipe(
81
+ prompt=prompt,
82
+ generator=generator,
83
+ num_inference_steps=num_inference_steps,
84
+ width=width,
85
+ height=height,
86
+ guidance_scale=guidance_scale
87
+ ).images[0]
88
+
89
+ return image, prompt, seed
90
+
91
+ custom_css = """
92
+ .input-group, .output-group {
93
+ border: 1px solid #e0e0e0;
94
+ border-radius: 10px;
95
+ padding: 20px;
96
+ margin-bottom: 20px;
97
+ background-color: #f9f9f9;
98
+ }
99
+ .submit-btn {
100
+ background-color: #2980b9 !important;
101
+ color: white !important;
102
+ }
103
+ .submit-btn:hover {
104
+ background-color: #3498db !important;
105
+ }
106
+ """
107
+
108
+ title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
109
+ <p><center>
110
+ <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
111
+ <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
112
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
113
+ <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
114
+ </center></p>
115
+ """
116
+
117
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
118
+ gr.HTML(title)
119
+
120
+ with gr.Row():
121
+ with gr.Column(scale=1):
122
+ with gr.Group(elem_classes="input-group"):
123
+ input_image = gr.Image(label="Input Image (Florence-2 Captioner)")
124
+
125
+ with gr.Accordion("Advanced Settings", open=False):
126
+ text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
127
+ use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
128
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
129
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
130
+ width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
131
+ height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
132
+ guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
133
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
134
+
135
+ generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
136
+
137
+ with gr.Column(scale=1):
138
+ with gr.Group(elem_classes="output-group"):
139
+ output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
140
+ final_prompt = gr.Textbox(label="Final Prompt Used")
141
+ used_seed = gr.Number(label="Seed Used")
142
+
143
+ generate_btn.click(
144
+ fn=process_workflow,
145
+ inputs=[
146
+ input_image, text_prompt, use_enhancer, seed, randomize_seed,
147
+ width, height, guidance_scale, num_inference_steps
148
+ ],
149
+ outputs=[output_image, final_prompt, used_seed]
150
+ )
151
+
152
+ demo.launch(debug=True)
requirements (1).txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ huggingface_hub
3
+ accelerate
4
+ git+https://github.com/huggingface/diffusers.git
5
+ torch==2.2.0
6
+ torchvision==0.17.0
7
+ transformers==4.42.4
8
+ xformers
9
+ sentencepiece
10
+ timm
11
+ einops