nightfury commited on
Commit
a8d228d
·
1 Parent(s): 83939f5

Update colorization.py

Browse files
Files changed (1) hide show
  1. colorization.py +10 -9
colorization.py CHANGED
@@ -6,12 +6,10 @@ from pathlib import Path
6
  from datetime import datetime
7
  from typing import Optional
8
 
9
-
10
  model = hub.Module(name='deoldify')
11
  # NOTE: Max is 45 with 11GB video cards. 35 is a good default
12
  render_factor=35
13
 
14
-
15
  def colorize_image(image):
16
  # now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
17
  if not os.path.exists("./output"):
@@ -19,13 +17,12 @@ def colorize_image(image):
19
  # if image is not None:
20
  # image.save(f"./output/{now}-input.jpg")
21
  model.predict(image.name)
22
-
23
  return f'./output/DeOldify/'+Path(image.name).stem+".png"
24
 
25
  # def inference(img, version, scale, weight):
26
- def inference(img, version, scale):
27
  # weight /= 100
28
- print(img, version, scale)
29
  try:
30
  extension = os.path.splitext(os.path.basename(str(img)))[1]
31
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
@@ -45,6 +42,7 @@ def inference(img, version, scale):
45
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
46
  except RuntimeError as error:
47
  print('Error', error)
 
48
  try:
49
  if scale != 2:
50
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
@@ -52,22 +50,25 @@ def inference(img, version, scale):
52
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
53
  except Exception as error:
54
  print('wrong scale input.', error)
 
55
  if img_mode == 'RGBA': # RGBA images should be saved in png format
56
  extension = 'png'
57
  else:
58
  extension = 'jpg'
59
- save_path = f'output/out.{extension}'
 
 
 
60
  cv2.imwrite(save_path, output)
61
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
62
  return output, save_path
63
  except Exception as error:
64
  print('global exception', error)
65
  return None, None
66
 
67
 
68
-
69
  colorize = gr.Interface(
70
- inference, [
71
  gr.Markdown("Colorize old black & white photos"),
72
  gr.Image(type="filepath", label="Input"),
73
 
 
6
  from datetime import datetime
7
  from typing import Optional
8
 
 
9
  model = hub.Module(name='deoldify')
10
  # NOTE: Max is 45 with 11GB video cards. 35 is a good default
11
  render_factor=35
12
 
 
13
  def colorize_image(image):
14
  # now = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
15
  if not os.path.exists("./output"):
 
17
  # if image is not None:
18
  # image.save(f"./output/{now}-input.jpg")
19
  model.predict(image.name)
 
20
  return f'./output/DeOldify/'+Path(image.name).stem+".png"
21
 
22
  # def inference(img, version, scale, weight):
23
+ def inferenceColorize(img, version, scale):
24
  # weight /= 100
25
+ print(img, scale)
26
  try:
27
  extension = os.path.splitext(os.path.basename(str(img)))[1]
28
  img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
 
42
  _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
43
  except RuntimeError as error:
44
  print('Error', error)
45
+
46
  try:
47
  if scale != 2:
48
  interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
 
50
  output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
51
  except Exception as error:
52
  print('wrong scale input.', error)
53
+
54
  if img_mode == 'RGBA': # RGBA images should be saved in png format
55
  extension = 'png'
56
  else:
57
  extension = 'jpg'
58
+
59
+ model.predict(img.name)
60
+ save_path = f'output/DeOldify/'+Path(img.name).stem+'.{extension}'
61
+
62
  cv2.imwrite(save_path, output)
63
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
64
  return output, save_path
65
  except Exception as error:
66
  print('global exception', error)
67
  return None, None
68
 
69
 
 
70
  colorize = gr.Interface(
71
+ inferenceColorize, [
72
  gr.Markdown("Colorize old black & white photos"),
73
  gr.Image(type="filepath", label="Input"),
74