fixing bugs
Browse files
app.py
CHANGED
@@ -46,47 +46,55 @@ pipe.load_lora_weights(
|
|
46 |
peft_backend="peft"
|
47 |
)
|
48 |
|
49 |
-
|
|
|
50 |
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
|
51 |
|
52 |
-
|
53 |
# ========== Fluxo Integrado ==========
|
54 |
def full_pipeline(image, scale_factor, model_type, style_prompt):
|
55 |
-
# 1. Super-Resolution
|
56 |
sr_model = model_edsr if model_type == "EDSR" else model_rdn
|
57 |
sr_params = params_edsr if model_type == "EDSR" else params_rdn
|
58 |
-
sr_image = process(np.array(image) / 255., sr_model, sr_params,
|
59 |
-
(round(image.size[1] * scale_factor),
|
60 |
-
round(image.size[0] * scale_factor)),
|
61 |
-
True)
|
62 |
|
63 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
|
65 |
bas_relief = pipe(
|
66 |
prompt=prompt,
|
67 |
-
image=
|
68 |
strength=0.6,
|
69 |
num_inference_steps=25,
|
70 |
guidance_scale=7.5
|
71 |
).images[0]
|
72 |
|
73 |
-
# 3. Depth Map
|
74 |
-
inputs =
|
75 |
with torch.no_grad():
|
76 |
outputs = depth_model(**inputs)
|
77 |
depth = outputs.predicted_depth
|
78 |
|
79 |
depth = torch.nn.functional.interpolate(
|
80 |
depth.unsqueeze(1),
|
81 |
-
|
82 |
-
|
83 |
).squeeze().cpu().numpy()
|
84 |
|
85 |
-
depth = (depth - depth.min()) / (depth.max() - depth.min())
|
86 |
depth = (depth * 255).astype(np.uint8)
|
87 |
|
88 |
-
return
|
89 |
-
|
90 |
|
91 |
# ========== Interface Gradio ==========
|
92 |
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
|
@@ -97,8 +105,10 @@ with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
|
|
97 |
input_image = gr.Image(label="Input Image", type="pil")
|
98 |
scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
|
99 |
model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
|
100 |
-
style_prompt = gr.Textbox(
|
101 |
-
|
|
|
|
|
102 |
process_btn = gr.Button("Start Pipeline")
|
103 |
|
104 |
with gr.Column():
|
|
|
46 |
peft_backend="peft"
|
47 |
)
|
48 |
|
49 |
+
# ========== Configuração do Modelo de Profundidade ==========
|
50 |
+
depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Nome padronizado
|
51 |
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
|
52 |
|
|
|
53 |
# ========== Fluxo Integrado ==========
|
54 |
def full_pipeline(image, scale_factor, model_type, style_prompt):
|
55 |
+
# 1. Super-Resolution (JAX)
|
56 |
sr_model = model_edsr if model_type == "EDSR" else model_rdn
|
57 |
sr_params = params_edsr if model_type == "EDSR" else params_rdn
|
|
|
|
|
|
|
|
|
58 |
|
59 |
+
# Processar e converter para numpy array
|
60 |
+
sr_jax = process(np.array(image) / 255., sr_model, sr_params,
|
61 |
+
(round(image.size[1] * scale_factor),
|
62 |
+
round(image.size[0] * scale_factor)),
|
63 |
+
True)
|
64 |
+
|
65 |
+
# Conversão crítica: JAX Array → numpy → PIL
|
66 |
+
sr_np = np.asarray(sr_jax)
|
67 |
+
sr_pil = Image.fromarray(sr_np)
|
68 |
+
|
69 |
+
if device == "cpu":
|
70 |
+
return sr_pil, None, None
|
71 |
+
|
72 |
+
# 2. Style Transfer (PyTorch)
|
73 |
prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
|
74 |
bas_relief = pipe(
|
75 |
prompt=prompt,
|
76 |
+
image=sr_pil, # Usar PIL Image diretamente
|
77 |
strength=0.6,
|
78 |
num_inference_steps=25,
|
79 |
guidance_scale=7.5
|
80 |
).images[0]
|
81 |
|
82 |
+
# 3. Depth Map
|
83 |
+
inputs = depth_processor(bas_relief, return_tensors="pt").to(device)
|
84 |
with torch.no_grad():
|
85 |
outputs = depth_model(**inputs)
|
86 |
depth = outputs.predicted_depth
|
87 |
|
88 |
depth = torch.nn.functional.interpolate(
|
89 |
depth.unsqueeze(1),
|
90 |
+
mode="bicubic",
|
91 |
+
size=bas_relief.size[::-1]
|
92 |
).squeeze().cpu().numpy()
|
93 |
|
94 |
+
depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
|
95 |
depth = (depth * 255).astype(np.uint8)
|
96 |
|
97 |
+
return sr_pil, bas_relief, Image.fromarray(depth)
|
|
|
98 |
|
99 |
# ========== Interface Gradio ==========
|
100 |
with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
|
|
|
105 |
input_image = gr.Image(label="Input Image", type="pil")
|
106 |
scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
|
107 |
model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
|
108 |
+
style_prompt = gr.Textbox(
|
109 |
+
label="Style Prompt",
|
110 |
+
value="insanely detailed and complex engraving relief, ultra-high definition" # <-- Alteração aqui
|
111 |
+
)
|
112 |
process_btn = gr.Button("Start Pipeline")
|
113 |
|
114 |
with gr.Column():
|