File size: 1,393 Bytes
7eafae4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import functools

import torch
import numpy as np
import gradio as gr

from PIL import Image
from models import dehazeformer


@functools.lru_cache(maxsize=1)
def _load_network():
	"""Build DehazeFormer, load its CPU checkpoint and put it in eval mode.

	Cached so the checkpoint is read from disk once per process rather than
	on every request; all calls share the same network instance.
	"""
	network = dehazeformer()
	checkpoint = torch.load('./saved_models/dehazeformer.pth', map_location=torch.device('cpu'))
	network.load_state_dict(checkpoint['state_dict'])
	network.eval()
	return network


def infer(raw_image):
	"""Dehaze a single image with the pretrained DehazeFormer network.

	Args:
		raw_image: image supplied by the Gradio "image" input. Presumably an
			(H, W, C) array-like with 8-bit values in [0, 255] — TODO confirm
			against the Gradio version in use. A 2-D grayscale array is
			replicated to three channels and a 4-channel (RGBA) array has its
			alpha channel dropped, so the network always receives RGB.

	Returns:
		PIL.Image.Image: the dehazed result as an 8-bit image.
	"""
	network = _load_network()

	image = np.array(raw_image, np.float32)
	if image.ndim == 2:
		# Grayscale input: replicate into three identical channels.
		image = np.repeat(image[..., None], 3, axis=-1)
	elif image.shape[-1] == 4:
		# RGBA input: the network takes RGB only, so drop alpha.
		image = image[..., :3]

	# Map [0, 255] -> [-1, 1], matching the range the output is clamped to below.
	image = image / 255. * 2 - 1
	image = torch.from_numpy(image).permute((2, 0, 1)).unsqueeze(0)  # HWC -> 1CHW

	with torch.no_grad():
		output = network(image).clamp_(-1, 1)[0] * 0.5 + 0.5  # [-1, 1] -> [0, 1]
		output = output.permute((1, 2, 0)).numpy()  # CHW -> HWC

	# Clamped to [0, 1] above, so the rounded values fit in uint8 without wrapping.
	return Image.fromarray(np.round(output * 255.0).astype(np.uint8))


# Text and sample inputs shown on the demo page.
title = "DehazeFormer"
description = "We use a mixed dataset to train the model, allowing the trained model to work better on real hazy images. To allow the model to process high-resolution images more efficiently and effectively, we extend it to the [MCT](https://github.com/IDKiro/MCT) variant."
# Each example is a one-element list because the interface has a single input.
examples = [
	["examples/1.jpg"],
	["examples/2.jpg"],
	["examples/3.jpg"],
	["examples/4.jpg"],
	["examples/5.jpg"],
	["examples/6.jpg"],
]

# Single image in, single image out, backed by `infer`.
iface = gr.Interface(
	infer,
	inputs="image", outputs="image",
	title=title,
	description=description,
	# NOTE(review): newer Gradio releases replace `allow_flagging` with
	# `flagging_mode`; update this if the pinned Gradio version changes.
	allow_flagging='never',
	examples=examples,
)
# Launches unconditionally at import time, as the app is run directly as a script.
iface.launch()