ds1david committed on
Commit
8c7829e
·
1 Parent(s): 1f384c6

fixing bugs

Browse files
Files changed (1) hide show
  1. app.py +28 -18
app.py CHANGED
@@ -46,47 +46,55 @@ pipe.load_lora_weights(
46
  peft_backend="peft"
47
  )
48
 
49
- feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
 
50
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
51
 
52
-
53
  # ========== Fluxo Integrado ==========
54
  def full_pipeline(image, scale_factor, model_type, style_prompt):
55
- # 1. Super-Resolution
56
  sr_model = model_edsr if model_type == "EDSR" else model_rdn
57
  sr_params = params_edsr if model_type == "EDSR" else params_rdn
58
- sr_image = process(np.array(image) / 255., sr_model, sr_params,
59
- (round(image.size[1] * scale_factor),
60
- round(image.size[0] * scale_factor)),
61
- True)
62
 
63
- # 2. Bas-Relief Style Transfer
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
65
  bas_relief = pipe(
66
  prompt=prompt,
67
- image=sr_image,
68
  strength=0.6,
69
  num_inference_steps=25,
70
  guidance_scale=7.5
71
  ).images[0]
72
 
73
- # 3. Depth Map Estimation
74
- inputs = feature_extractor(bas_relief, return_tensors="pt").to(device)
75
  with torch.no_grad():
76
  outputs = depth_model(**inputs)
77
  depth = outputs.predicted_depth
78
 
79
  depth = torch.nn.functional.interpolate(
80
  depth.unsqueeze(1),
81
- size=bas_relief.size[::-1],
82
- mode="bicubic"
83
  ).squeeze().cpu().numpy()
84
 
85
- depth = (depth - depth.min()) / (depth.max() - depth.min())
86
  depth = (depth * 255).astype(np.uint8)
87
 
88
- return sr_image, bas_relief, Image.fromarray(depth)
89
-
90
 
91
  # ========== Interface Gradio ==========
92
  with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
@@ -97,8 +105,10 @@ with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
97
  input_image = gr.Image(label="Input Image", type="pil")
98
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
99
  model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
100
- style_prompt = gr.Textbox(label="Style Prompt",
101
- placeholder="marble sculpture, ancient greek style")
 
 
102
  process_btn = gr.Button("Start Pipeline")
103
 
104
  with gr.Column():
 
46
  peft_backend="peft"
47
  )
48
 
49
+ # ========== Configuração do Modelo de Profundidade ==========
50
+ depth_processor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large") # Nome padronizado
51
  depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
52
 
 
53
  # ========== Fluxo Integrado ==========
54
  def full_pipeline(image, scale_factor, model_type, style_prompt):
55
+ # 1. Super-Resolution (JAX)
56
  sr_model = model_edsr if model_type == "EDSR" else model_rdn
57
  sr_params = params_edsr if model_type == "EDSR" else params_rdn
 
 
 
 
58
 
59
+ # Processar e converter para numpy array
60
+ sr_jax = process(np.array(image) / 255., sr_model, sr_params,
61
+ (round(image.size[1] * scale_factor),
62
+ round(image.size[0] * scale_factor)),
63
+ True)
64
+
65
+ # Conversão crítica: JAX Array → numpy → PIL
66
+ sr_np = np.asarray(sr_jax)
67
+ sr_pil = Image.fromarray(sr_np)
68
+
69
+ if device == "cpu":
70
+ return sr_pil, None, None
71
+
72
+ # 2. Style Transfer (PyTorch)
73
  prompt = f"BAS-RELIEF {style_prompt}, intricate carving, marble texture"
74
  bas_relief = pipe(
75
  prompt=prompt,
76
+ image=sr_pil, # Usar PIL Image diretamente
77
  strength=0.6,
78
  num_inference_steps=25,
79
  guidance_scale=7.5
80
  ).images[0]
81
 
82
+ # 3. Depth Map
83
+ inputs = depth_processor(bas_relief, return_tensors="pt").to(device)
84
  with torch.no_grad():
85
  outputs = depth_model(**inputs)
86
  depth = outputs.predicted_depth
87
 
88
  depth = torch.nn.functional.interpolate(
89
  depth.unsqueeze(1),
90
+ mode="bicubic",
91
+ size=bas_relief.size[::-1]
92
  ).squeeze().cpu().numpy()
93
 
94
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
95
  depth = (depth * 255).astype(np.uint8)
96
 
97
+ return sr_pil, bas_relief, Image.fromarray(depth)
 
98
 
99
  # ========== Interface Gradio ==========
100
  with gr.Blocks(title="TheraSR + Bas-Relief Fusion") as app:
 
105
  input_image = gr.Image(label="Input Image", type="pil")
106
  scale = gr.Slider(1.0, 4.0, value=2.0, label="Scale Factor")
107
  model_type = gr.Radio(["EDSR", "RDN"], value="EDSR", label="SR Model")
108
+ style_prompt = gr.Textbox(
109
+ label="Style Prompt",
110
+ value="insanely detailed and complex engraving relief, ultra-high definition" # <-- Alteração aqui
111
+ )
112
  process_btn = gr.Button("Start Pipeline")
113
 
114
  with gr.Column():