nightfury commited on
Commit
b326d68
·
1 Parent(s): d9451cf

Update colorization.py

Browse files
Files changed (1) hide show
  1. colorization.py +66 -0
colorization.py CHANGED
@@ -22,6 +22,72 @@ def colorize_image(image):
22
 
23
  return f'./output/DeOldify/'+Path(image.name).stem+".png"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def create_interface():
27
  with gr.Blocks() as enhancer:
 
22
 
23
  return f'./output/DeOldify/'+Path(image.name).stem+".png"
24
 
25
+ # def inference(img, version, scale, weight):
26
+ def inference(img, version, scale):
27
+ # weight /= 100
28
+ print(img, version, scale)
29
+ try:
30
+ extension = os.path.splitext(os.path.basename(str(img)))[1]
31
+ img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
32
+ if len(img.shape) == 3 and img.shape[2] == 4:
33
+ img_mode = 'RGBA'
34
+ elif len(img.shape) == 2: # for gray inputs
35
+ img_mode = None
36
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
37
+ else:
38
+ img_mode = None
39
+ h, w = img.shape[0:2]
40
+ if h < 300:
41
+ img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
42
+
43
+ try:
44
+ # _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True, weight=weight)
45
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
46
+ except RuntimeError as error:
47
+ print('Error', error)
48
+ try:
49
+ if scale != 2:
50
+ interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
51
+ h, w = img.shape[0:2]
52
+ output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
53
+ except Exception as error:
54
+ print('wrong scale input.', error)
55
+ if img_mode == 'RGBA': # RGBA images should be saved in png format
56
+ extension = 'png'
57
+ else:
58
+ extension = 'jpg'
59
+ save_path = f'output/out.{extension}'
60
+ cv2.imwrite(save_path, output)
61
+ output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
62
+ return output, save_path
63
+ except Exception as error:
64
+ print('global exception', error)
65
+ return None, None
66
+
67
+
68
+
69
+ demo = gr.Interface(
70
+ inference, [
71
+ gr.inputs.Image(type="filepath", label="Input"),
72
+ # gr.inputs.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer'], type="value", default='v1.4', label='version'),
73
+ gr.inputs.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer','CodeFormer','RealESR-General-x4v3'], type="value", default='v1.4', label='version'),
74
+ gr.inputs.Number(label="Rescaling factor", default=2),
75
+ # gr.Slider(0, 100, label='Weight, only for CodeFormer. 0 for better quality, 100 for better identity', default=50)
76
+ ], [
77
+ gr.outputs.Image(type="numpy", label="Output (The whole image)"),
78
+ gr.outputs.File(label="Download the output image")
79
+ ],
80
+ title=title,
81
+ description=description,
82
+ article=article,
83
+ # examples=[['AI-generate.jpg', 'v1.4', 2, 50], ['lincoln.jpg', 'v1.4', 2, 50], ['Blake_Lively.jpg', 'v1.4', 2, 50],
84
+ # ['10045.png', 'v1.4', 2, 50]]).launch()
85
+ examples=[['a1.jpg', 'v1.4', 2], ['a2.jpg', 'v1.4', 2], ['a3.jpg', 'v1.4', 2],['a4.jpg', 'v1.4', 2]])
86
+
87
+ demo.queue(concurrency_count=4)
88
+ demo.launch()
89
+
90
+
91
 
92
  def create_interface():
93
  with gr.Blocks() as enhancer: