atlury commited on
Commit
53ce97d
·
verified ·
1 Parent(s): 21f63e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -6,12 +6,13 @@ from PIL import Image
6
  import numpy as np
7
 
8
  import gradio as gr
9
- from model import IAT
10
 
11
  def set_example_image(example: list) -> dict:
12
  return gr.Image.update(value=example[0])
13
 
14
  def tensor_to_numpy(tensor):
 
15
  tensor = tensor.detach().cpu().numpy()
16
  if tensor.ndim == 3 and tensor.shape[0] == 3: # Convert CHW to HWC
17
  tensor = tensor.transpose(1, 2, 0)
@@ -19,6 +20,7 @@ def tensor_to_numpy(tensor):
19
  return tensor
20
 
21
  def dark_inference(img):
 
22
  model = IAT()
23
  checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
24
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
@@ -33,15 +35,17 @@ def dark_inference(img):
33
  ConvertImageDtype(torch.float)
34
  ])
35
  input_img = transform(img)
36
- print(f'Image shape: {input_img.shape}')
37
 
38
  with torch.no_grad():
39
  enhanced_img = model(input_img.unsqueeze(0))
40
 
41
  result_img = tensor_to_numpy(enhanced_img[0])
 
42
  return result_img
43
 
44
  def exposure_inference(img):
 
45
  model = IAT()
46
  checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
47
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
@@ -56,12 +60,13 @@ def exposure_inference(img):
56
  ConvertImageDtype(torch.float)
57
  ])
58
  input_img = transform(img)
59
- print(f'Image shape: {input_img.shape}')
60
 
61
  with torch.no_grad():
62
  enhanced_img = model(input_img.unsqueeze(0))
63
 
64
  result_img = tensor_to_numpy(enhanced_img[0])
 
65
  return result_img
66
 
67
  demo = gr.Blocks()
 
6
  import numpy as np
7
 
8
  import gradio as gr
9
+ from iatenhancement.model import IAT # Ensure the correct import path
10
 
11
  def set_example_image(example: list) -> dict:
12
  return gr.Image.update(value=example[0])
13
 
14
  def tensor_to_numpy(tensor):
15
+ print("Converting tensor to numpy array...")
16
  tensor = tensor.detach().cpu().numpy()
17
  if tensor.ndim == 3 and tensor.shape[0] == 3: # Convert CHW to HWC
18
  tensor = tensor.transpose(1, 2, 0)
 
20
  return tensor
21
 
22
  def dark_inference(img):
23
+ print("Starting dark inference...")
24
  model = IAT()
25
  checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
26
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
 
35
  ConvertImageDtype(torch.float)
36
  ])
37
  input_img = transform(img)
38
+ print(f'Image shape after transform: {input_img.shape}')
39
 
40
  with torch.no_grad():
41
  enhanced_img = model(input_img.unsqueeze(0))
42
 
43
  result_img = tensor_to_numpy(enhanced_img[0])
44
+ print("Dark inference completed.")
45
  return result_img
46
 
47
  def exposure_inference(img):
48
+ print("Starting exposure inference...")
49
  model = IAT()
50
  checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
51
  state_dict = torch.load(checkpoint_file_path, map_location='cpu')
 
60
  ConvertImageDtype(torch.float)
61
  ])
62
  input_img = transform(img)
63
+ print(f'Image shape after transform: {input_img.shape}')
64
 
65
  with torch.no_grad():
66
  enhanced_img = model(input_img.unsqueeze(0))
67
 
68
  result_img = tensor_to_numpy(enhanced_img[0])
69
+ print("Exposure inference completed.")
70
  return result_img
71
 
72
  demo = gr.Blocks()