doevent commited on
Commit
352c33a
·
verified ·
1 Parent(s): 4916e36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from KandiSuperRes import get_SR_pipeline
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+ sr_pipe2x = get_SR_pipeline(device=device, fp16=True, flash=True, scale=2)
8
+ sr_pipe4x = get_SR_pipeline(device=device, fp16=True, flash=True, scale=4)
9
+
10
+
11
+ def inference(image, size):
12
+ if image is None:
13
+ raise gr.Error("Image not uploaded")
14
+ r_image = Image.open(image)
15
+
16
+ if size == '2x':
17
+ result = sr_pipe2x(lr_image)
18
+ else:
19
+ result = sr_pipe4x(lr_image)
20
+
21
+ print(f"Image size ({device}): {size} ... OK")
22
+ return result
23
+
24
+
25
+ title = "KandiSuperRes - diffusion model for super resolution"
26
+ description = "KandiSuperRes Flash is a new version of the diffusion model for super resolution. This model includes a distilled version of the KandiSuperRes model and a distilled model Kandinsky 3.0 Flash. KandiSuperRes Flash not only improves image clarity, but also corrects artifacts, draws details, improves image aesthetics. And one of the most important advantages is the ability to use the model in the "infinite super resolution" mode."
27
+ article = "<div style='text-align: center;'>Twitter <a href='https://twitter.com/DoEvent' target='_blank'>Max Skobeev</a> | <a href='https://huggingface.co/ai-forever/KandiSuperRes' target='_blank'>Model card</a><div>"
28
+
29
+
30
+ gr.Interface(inference,
31
+ [gr.Image(type="pil"),
32
+ gr.Radio(['2x', '4x'],
33
+ type="value",
34
+ value='2x',
35
+ label='Resolution model')],
36
+ gr.Image(type="pil", label="Output"),
37
+ title=title,
38
+ description=description,
39
+ article=article,
40
+ examples=[['groot.jpeg', "2x"]],
41
+ allow_flagging='never',
42
+ cache_examples=False,
43
+ ).queue(api_open=True).launch(show_error=True, show_api=True)
44
+