import gradio as gr

import os
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr
from contextlib import nullcontext

import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model

n_sample_image_default = 2
n_sample_text_default = 4
cache_examples = True
hfm_repo_id = 'shi-labs/versatile-diffusion-model'
hfm_filename = 'pretrained_pth/vd-four-flow-v1-0-fp16.pth'

def highlight_print(info):
    print('')
    print(''.join(['#']*(len(info)+4)))
    print('# '+info+' #')
    print(''.join(['#']*(len(info)+4)))
    print('')

class color_adjust(object):
    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)

    def get_data_and_stat(self, x):
        if isinstance(x, str):
            x = np.array(PIL.Image.open(x))
        elif isinstance(x, PIL.Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError('Unsupported reference type for color adjustment')
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s

    def preprocess(self, x):
        # Match the per-channel mean/std of the source to the target reference.
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x-m0)/s0)*s1 + m1
        return y

    def __call__(self, xin, keep=0, simple=False):
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            y = (x*(1-keep) + xin*keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)

        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = (y.astype(float)*(1-keep) + xin.astype(float)*keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y

    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        # Histogram matching: map the CDF of the source channel onto the CDF of the
        # reference channel, then transfer the input values through that mapping.
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]

        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]

        t_d = np.interp(d_fo, d_to, xs)
        t_d[d_fo <= d_to[ 0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out

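# Hedged usage sketch (illustrative only, not executed by the app): color_adjust
# matches a generated image to the color statistics of a reference image, roughly:
#   adj = color_adjust(ref_from=generated_image, ref_to=reference_image)
#   calibrated = adj(generated_image, keep=0.5, simple=False)
# `generated_image` and `reference_image` are placeholder names; any path, PIL image,
# numpy array, or [0, 1] tensor accepted by get_data_and_stat would work here.
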
class vd_inference(object):
    def __init__(self, pth=None, hfm_repo=None, fp16=False, device=0):
        cfgm_name = 'vd_noema'
        cfgm = model_cfg_bank()(cfgm_name)
        net = get_model()(cfgm)
        if fp16:
            highlight_print('Running in FP16')
            net.clip.fp16 = True
            net = net.half()
        if pth is not None:
            sd = torch.load(pth, map_location='cpu')
            print('Load pretrained weight from {}'.format(pth))
        else:
            # Fetch the checkpoint from the HuggingFace Hub when no local path is given.
            from huggingface_hub import hf_hub_download
            temppath = hf_hub_download(hfm_repo[0], hfm_repo[1])
            sd = torch.load(temppath, map_location='cpu')
            print('Load pretrained weight from {}/{}'.format(*hfm_repo))

        net.load_state_dict(sd, strict=False)
        net.to(device)

        self.device = device
        self.model_name = cfgm_name
        self.net = net
        self.fp16 = fp16
        from lib.model_zoo.ddim_vd import DDIMSampler_VD
        self.sampler = DDIMSampler_VD(net)

    def regularize_image(self, x):
        BICUBIC = PIL.Image.Resampling.BICUBIC
        if isinstance(x, str):
            x = Image.open(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, PIL.Image.Image):
            x = x.resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, np.ndarray):
            x = PIL.Image.fromarray(x).resize([512, 512], resample=BICUBIC)
            x = tvtrans.ToTensor()(x)
        elif isinstance(x, torch.Tensor):
            pass
        else:
            assert False, 'Unknown image type'

        assert (x.shape[1] == 512) and (x.shape[2] == 512), \
            'Wrong image size'
        x = x.to(self.device)
        if self.fp16:
            x = x.half()
        return x

    def decode(self, z, xtype, ctype, color_adj='None', color_adj_to=None):
        net = self.net
        if xtype == 'image':
            x = net.autokl_decode(z)

            color_adj_flag = (color_adj != 'none') and (color_adj != 'None') and (color_adj is not None)
            color_adj_simple = (color_adj == 'Simple') or (color_adj == 'simple')
            color_adj_keep_ratio = 0.5

            if color_adj_flag and (ctype == 'vision'):
                # Calibrate each decoded image against the conditioning image.
                x_adj = []
                for xi in x:
                    color_adj_f = color_adjust(ref_from=(xi+1)/2, ref_to=color_adj_to)
                    xi_adj = color_adj_f((xi+1)/2, keep=color_adj_keep_ratio, simple=color_adj_simple)
                    x_adj.append(xi_adj)
                x = x_adj
            else:
                x = torch.clamp((x+1.0)/2.0, min=0.0, max=1.0)
                x = [tvtrans.ToPILImage()(xi) for xi in x]
            return x

        elif xtype == 'text':
            prompt_temperature = 1.0
            prompt_merge_same_adj_word = True
            x = net.optimus_decode(z, temperature=prompt_temperature)
            if prompt_merge_same_adj_word:
                # Collapse consecutive repeated words in the decoded prompts.
                xnew = []
                for xi in x:
                    xi_split = xi.split()
                    xinew = []
                    for idxi, wi in enumerate(xi_split):
                        if idxi != 0 and wi == xi_split[idxi-1]:
                            continue
                        xinew.append(wi)
                    xnew.append(' '.join(xinew))
                x = xnew
            return x

    def inference(self, xtype, cin, ctype, scale=7.5, n_samples=None, color_adj=None):
        net = self.net
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0

        if xtype == 'image':
            n_samples = n_sample_image_default if n_samples is None else n_samples
        elif xtype == 'text':
            n_samples = n_sample_text_default if n_samples is None else n_samples

        if ctype in ['prompt', 'text']:
            c = net.clip_encode_text(n_samples * [cin])
            u = None
            if scale != 1.0:
                u = net.clip_encode_text(n_samples * [""])

        elif ctype in ['vision', 'image']:
            cin = self.regularize_image(cin)
            ctemp = cin*2 - 1
            ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
            c = net.clip_encode_vision(ctemp)
            u = None
            if scale != 1.0:
                dummy = torch.zeros_like(ctemp)
                u = net.clip_encode_vision(dummy)

        if self.fp16:
            # Guard against u being None when classifier-free guidance is disabled (scale == 1.0).
            c = c.half()
            u = u.half() if u is not None else u

        if xtype == 'image':
            h, w = [512, 512]
            shape = [n_samples, 4, h//8, w//8]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
            return x

        elif xtype == 'text':
            n = 768
            shape = [n_samples, n]
            z, _ = sampler.sample(
                steps=ddim_steps,
                shape=shape,
                conditioning=c,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=u,
                xtype=xtype, ctype=ctype,
                eta=ddim_eta,
                verbose=False,)
            x = self.decode(z, xtype, ctype)
            return x

    def application_disensemble(self, cin, n_samples=None, level=0, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        cin = self.regularize_image(cin)
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)
        if self.fp16:
            c = c.half()
            u = u.half() if u is not None else u

        if level == 0:
            pass
        else:
            # Split the CLIP vision embedding into the global token and local tokens,
            # then bias the local tokens towards semantics (negative levels) or style
            # (positive levels) via low-rank filtering.
            c_glb = c[:, 0:1]
            c_loc = c[:, 1:]
            u_glb = u[:, 0:1]
            u_loc = u[:, 1:]

            if level == -1:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=1)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=1)
            if level == -2:
                c_loc = self.remove_low_rank(c_loc, demean=True, q=50, q_remove=2)
                u_loc = self.remove_low_rank(u_loc, demean=True, q=50, q_remove=2)
            if level == 1:
                c_loc = self.find_low_rank(c_loc, demean=True, q=10)
                u_loc = self.find_low_rank(u_loc, demean=True, q=10)
            if level == 2:
                c_loc = self.find_low_rank(c_loc, demean=True, q=2)
                u_loc = self.find_low_rank(u_loc, demean=True, q=2)

            c = torch.cat([c_glb, c_loc], dim=1)
            u = torch.cat([u_glb, u_loc], dim=1)

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=c,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=u,
            xtype='image', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
        return x

    def find_low_rank(self, x, demean=True, q=20, niter=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        # Keep only the top-q principal components of the token features.
        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

    def remove_low_rank(self, x, demean=True, q=20, niter=10, q_remove=10):
        if demean:
            x_mean = x.mean(-1, keepdim=True)
            x_input = x - x_mean
        else:
            x_input = x

        if x_input.dtype == torch.float16:
            fp16 = True
            x_input = x_input.float()
        else:
            fp16 = False

        # Zero out the q_remove strongest principal components before reconstructing
        # from the top-q basis.
        u, s, v = torch.pca_lowrank(x_input, q=q, center=False, niter=niter)
        s[:, 0:q_remove] = 0
        ss = torch.stack([torch.diag(si) for si in s])
        x_lowrank = torch.bmm(torch.bmm(u, ss), torch.permute(v, [0, 2, 1]))

        if fp16:
            x_lowrank = x_lowrank.half()

        if demean:
            x_lowrank += x_mean
        return x_lowrank

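    # Hedged sketch (illustrative only, not executed by the app): the two helpers
    # above implement PCA-based filtering of CLIP vision tokens. Assuming a dummy
    # batch of token features, the intended usage is roughly:
    #   feats = torch.randn(2, 256, 768)              # hypothetical [batch, tokens, dim]
    #   style_only = self.find_low_rank(feats, q=10)  # keep the top-10 components
    #   no_top = self.remove_low_rank(feats, q=50, q_remove=2)  # drop the 2 strongest
    # The Disentanglement application feeds these filtered tokens back as conditioning.
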
    def application_dualguided(self, cim, ctx, n_samples=None, mixing=0.5, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        ctx = net.clip_encode_text(n_samples * [ctx])
        utx = None
        if scale != 1.0:
            utx = net.clip_encode_text(n_samples * [""])

        if self.fp16:
            cim, ctx = cim.half(), ctx.half()
            uim = uim.half() if uim is not None else uim
            utx = utx.half() if utx is not None else utx

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]

        # sample_dc mixes the two conditionings; mixed_ratio weights the first
        # (vision) conditioning, so mixing=0 leans fully towards the image.
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim, cim],
            second_conditioning=[utx, ctx],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=(1-mixing),)
        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

    def application_i2t2i(self, cim, ctx_n, ctx_p, n_samples=None, color_adj=None):
        net = self.net
        scale = 7.5
        sampler = self.sampler
        ddim_steps = 50
        ddim_eta = 0.0
        prompt_temperature = 1.0
        n_samples = n_sample_image_default if n_samples is None else n_samples

        ctemp0 = self.regularize_image(cim)
        ctemp1 = ctemp0*2 - 1
        ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
        cim = net.clip_encode_vision(ctemp1)
        uim = None
        if scale != 1.0:
            dummy = torch.zeros_like(ctemp1)
            uim = net.clip_encode_vision(dummy)

        if self.fp16:
            cim = cim.half()
            uim = uim.half() if uim is not None else uim

        # Step 1: image-to-text, sampling a text latent conditioned on the image.
        n = 768
        shape = [n_samples, n]
        zt, _ = sampler.sample(
            steps=ddim_steps,
            shape=shape,
            conditioning=cim,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uim,
            xtype='text', ctype='vision',
            eta=ddim_eta,
            verbose=False,)
        ztn = net.optimus_encode([ctx_n])
        ztp = net.optimus_encode([ctx_p])

        # Step 2: text-latent editing, removing the projection onto the "remove"
        # prompt direction and adding the "add" prompt latent.
        ztn_norm = ztn / ztn.norm(dim=1)
        zt_proj_mag = torch.matmul(zt, ztn_norm[0])
        zt_perp = zt - zt_proj_mag[:, None] * ztn_norm
        zt_newd = zt_perp + ztp
        ctx_new = net.optimus_decode(zt_newd, temperature=prompt_temperature)

        # Step 3: text-to-image, re-encoding the edited prompts together with the
        # "add" prompt and a low-rank version of the image tokens.
        ctx_new = net.clip_encode_text(ctx_new)
        ctx_p = net.clip_encode_text([ctx_p])
        ctx_new = torch.cat([ctx_new, ctx_p.repeat(n_samples, 1, 1)], dim=1)
        utx_new = net.clip_encode_text(n_samples * [""])
        utx_new = torch.cat([utx_new, utx_new], dim=1)

        cim_loc = cim[:, 1:]
        cim_loc_new = self.find_low_rank(cim_loc, demean=True, q=10)
        cim_new = cim_loc_new
        uim_new = uim[:, 1:]

        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
        z, _ = sampler.sample_dc(
            steps=ddim_steps,
            shape=shape,
            first_conditioning=[uim_new, cim_new],
            second_conditioning=[utx_new, ctx_new],
            unconditional_guidance_scale=scale,
            xtype='image',
            first_ctype='vision',
            second_ctype='prompt',
            eta=ddim_eta,
            verbose=False,
            mixed_ratio=0.33,)

        x = self.decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=ctemp0)
        return x

vd_inference = vd_inference(hfm_repo=[hfm_repo_id, hfm_filename], fp16=True, device='cuda')

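# Note: this instance shadows the class name; everything below accesses the loaded
# model through this single global object. A hedged example of a direct call that
# bypasses the Gradio UI (the prompt string is only an illustrative placeholder):
#   images = vd_inference.inference(xtype='image', cin='a photo of a cat', ctype='prompt')
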
def main(mode,
         image=None,
         prompt=None,
         nprompt=None,
         pprompt=None,
         color_adj=None,
         disentanglement_level=None,
         dual_guided_mixing=None,
         seed=0,):

    if seed < 0:
        seed = 0
    np.random.seed(seed)
    torch.manual_seed(seed+100)

    if mode == 'Text-to-Image':
        if (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=prompt,
                ctype='prompt',)
        return rv, None
    elif mode == 'Image-Variation':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='image',
                cin=image,
                ctype='vision',
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Image-to-Text':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=image,
                ctype='vision',)
        return None, '\n'.join(rv)
    elif mode == 'Text-Variation':
        if prompt is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.inference(
                xtype='text',
                cin=prompt,
                ctype='prompt',)
        return None, '\n'.join(rv)
    elif mode == 'Disentanglement':
        if image is None:
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_disensemble(
                cin=image,
                level=disentanglement_level,
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Dual-Guided':
        if (image is None) or (prompt is None) or (prompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_dualguided(
                cim=image,
                ctx=prompt,
                mixing=dual_guided_mixing,
                color_adj=color_adj,)
        return rv, None
    elif mode == 'Latent-I2T2I':
        if (image is None) or (nprompt is None) or (nprompt == "") \
                or (pprompt is None) or (pprompt == ""):
            return None, None
        with torch.no_grad():
            rv = vd_inference.application_i2t2i(
                cim=image,
                ctx_n=nprompt,
                ctx_p=pprompt,
                color_adj=color_adj,)
        return rv, None
    else:
        assert False, "No such mode!"

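# main always returns a (gallery, text) pair so that a single Gradio callback can
# serve every mode; modes that produce only one kind of output return None for the
# other slot. For example, a hypothetical direct call such as
#   gallery, text = main('Text-to-Image', prompt='a beautiful grand nebula', seed=23)
# would yield a list of images in `gallery` and None in `text`.
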
def get_instruction(mode):
    t2i_instruction = ["Generate an image from a text prompt."]
    i2i_instruction = [
        "Generate an image conditioned on a reference image.",
        "Color Calibration provides an option to adjust the output colors to match the reference image.", ]
    i2t_instruction = ["Generate text from a reference image."]
    t2t_instruction = ["Generate text from a reference text prompt. (This flow is insufficiently trained, so results are still experimental.)"]
    dis_instruction = [
        "Generate a variation of the reference image that is disentangled towards semantics or style.",
        "Color Calibration provides an option to adjust the output colors to match the reference image.",
        "Disentanglement level controls the focus towards semantics (-2, -1) or style (1, 2). Level 0 behaves like Image-Variation.", ]
    dug_instruction = [
        "Generate an image from the dual guidance of a reference image and a text prompt.",
        "Color Calibration provides an option to adjust the output colors to match the reference image.",
        "Guidance Mixing linearly balances the image and text contexts (0 towards image, 1 towards text).", ]
    iti_instruction = [
        "Generate image variations via image-to-text, text-latent editing, and then text-to-image. (Still under exploration.)",
        "Color Calibration provides an option to adjust the output colors to match the reference image.",
        "The Remove Prompt is subtracted from the text latent code.",
        "The Add Prompt is added to the text latent code.", ]

    if mode == "Text-to-Image":
        return '\n'.join(t2i_instruction)
    elif mode == "Image-Variation":
        return '\n'.join(i2i_instruction)
    elif mode == "Image-to-Text":
        return '\n'.join(i2t_instruction)
    elif mode == "Text-Variation":
        return '\n'.join(t2t_instruction)
    elif mode == "Disentanglement":
        return '\n'.join(dis_instruction)
    elif mode == "Dual-Guided":
        return '\n'.join(dug_instruction)
    elif mode == "Latent-I2T2I":
        return '\n'.join(iti_instruction)

if True:
    img_output = gr.Gallery(label="Image Result").style(grid=n_sample_image_default)
    txt_output = gr.Textbox(lines=4, label='Text Result', visible=False)

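# The two output components above are created before the gr.Blocks context so that
# the same objects can be passed to btn.click and then attached to the layout later
# with img_output.render() / txt_output.render().
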
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="position: relative; float: left; text-align: center; width: 60%; min-width:600px; height: 160px; margin: 20px 0 20px 20%;">
            <h1 style="font-weight: 900; font-size: 3rem;">
                Versatile Diffusion
            </h1>
            <br>
            <h2 style="font-weight: 450; font-size: 1rem;">
                We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
                VD natively supports image-to-text, image-variation, text-to-image, and text-variation,
                and can be further extended to other applications such as
                semantic-style disentanglement, image-text dual-guided generation, latent image-to-text-to-image editing, and more.
                Future versions will support more modalities such as speech, music, video and 3D.
            </h2>
            <br>
            <h3>Xingqian Xu, Atlas Wang, Eric Zhang, Kai Wang,
                and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
                [<a href="https://arxiv.org/abs/2211.08332" style="color:blue;">arXiv</a>]
                [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
            </h3>
        </div>
        <div style="position: relative; float: right; width: 19.9%; min-width:200px; margin: 20px auto;">
            <img src="https://huggingface.co/spaces/shi-labs/Versatile-Diffusion/resolve/main/assets/figures/share_instruction.png">
        </div>
        """)

    mode_input = gr.Radio([
        "Text-to-Image", "Image-Variation", "Image-to-Text", "Text-Variation",
        "Disentanglement", "Dual-Guided", "Latent-I2T2I"], value='Text-to-Image', label="VD Flows and Applications")

    instruction = gr.Textbox(get_instruction("Text-to-Image"), label='Info')

    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label='Image Input', visible=False)
            txt_input = gr.Textbox(lines=4, placeholder="Input prompt...", label='Text Input')
            ntxt_input = gr.Textbox(label='Remove Prompt', visible=False)
            ptxt_input = gr.Textbox(label='Add Prompt', visible=False)
            coladj_input = gr.Radio(["None", "Simple"], value='Simple', label="Color Calibration", visible=False)
            dislvl_input = gr.Slider(-2, 2, value=0, step=1, label="Disentanglement level", visible=False)
            dguide_input = gr.Slider(0, 1, value=0.5, step=0.01, label="Guidance Mixing", visible=False)
            seed_input = gr.Number(100, label="Seed", precision=0)

            btn = gr.Button("Run")
            btn.click(
                main,
                inputs=[
                    mode_input,
                    img_input,
                    txt_input,
                    ntxt_input,
                    ptxt_input,
                    coladj_input,
                    dislvl_input,
                    dguide_input,
                    seed_input, ],
                outputs=[img_output, txt_output])

        with gr.Column():
            img_output.render()
            txt_output.render()

    example_mode = [
        "Text-to-Image",
        "Image-Variation",
        "Image-to-Text",
        "Text-Variation",
        "Disentanglement",
        "Dual-Guided",
        "Latent-I2T2I"]

    def get_example(mode):
        if mode == 'Text-to-Image':
            case = [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 23],
                ['a beautiful grand nebula in the universe', 24],
                ['heavy arms gundam penguin mech', 25],
            ]
        elif mode == "Image-Variation":
            case = [
                ['assets/space.jpg', 'None', 26],
                ['assets/train.jpg', 'Simple', 27],
            ]
        elif mode == "Image-to-Text":
            case = [
                ['assets/boy_and_girl.jpg', 28],
                ['assets/house_by_lake.jpg', 29],
            ]
        elif mode == "Text-Variation":
            case = [
                ['a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ', 32],
                ['a beautiful grand nebula in the universe', 33],
                ['heavy arms gundam penguin mech', 34],
            ]
        elif mode == "Disentanglement":
            case = [
                ['assets/vermeer.jpg', 'Simple', -2, 30],
                ['assets/matisse.jpg', 'Simple', 2, 31],
            ]
        elif mode == "Dual-Guided":
            case = [
                ['assets/benz.jpg', 'cyberpunk 2077', 'Simple', 0.75, 22],
                ['assets/vermeer.jpg', 'a girl with a diamond necklace', 'Simple', 0.66, 21],
            ]
        elif mode == "Latent-I2T2I":
            case = [
                ['assets/ghibli.jpg', 'white house', 'tall castle', 'Simple', 20],
                ['assets/matisse.jpg', 'fruits and bottles on the table', 'flowers on the table', 'Simple', 21],
            ]
        else:
            raise ValueError
        case = [[mode] + casei for casei in case]
        return case

    def get_example_iof(mode):
        if mode == 'Text-to-Image':
            inps = [txt_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[0]
        elif mode == "Image-Variation":
            inps = [img_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z: \
                main(mode=m, image=x, color_adj=y, seed=z)[0]
        elif mode == "Image-to-Text":
            inps = [img_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, image=x, seed=y)[1]
        elif mode == "Text-Variation":
            inps = [txt_input, seed_input]
            oups = [txt_output]
            fn = lambda m, x, y: \
                main(mode=m, prompt=x, seed=y)[1]
        elif mode == "Disentanglement":
            inps = [img_input, coladj_input, dislvl_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w: \
                main(mode=m, image=x, color_adj=y, disentanglement_level=z, seed=w)[0]
        elif mode == "Dual-Guided":
            inps = [img_input, txt_input, coladj_input, dguide_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, prompt=y, color_adj=z, dual_guided_mixing=w, seed=u)[0]
        elif mode == "Latent-I2T2I":
            inps = [img_input, ntxt_input, ptxt_input, coladj_input, seed_input]
            oups = [img_output]
            fn = lambda m, x, y, z, w, u: \
                main(mode=m, image=x, nprompt=y, pprompt=z, color_adj=w, seed=u)[0]
        else:
            raise ValueError
        return [mode_input]+inps, oups, fn

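    # get_example_iof returns, for each mode, the Gradio input components (with
    # mode_input prepended to match the example rows built by get_example), the
    # output components, and a wrapper around main() that picks the relevant slot
    # of its (gallery, text) return value so cached examples render correctly.
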
    with gr.Row():
        for emode in example_mode[0:4]:
            with gr.Column():
                gr.Examples(
                    label=emode+' Examples',
                    examples=get_example(emode),
                    inputs=get_example_iof(emode)[0],
                    outputs=get_example_iof(emode)[1],
                    fn=get_example_iof(emode)[2],
                    cache_examples=cache_examples)
    with gr.Row():
        for emode in example_mode[4:7]:
            with gr.Column():
                gr.Examples(
                    label=emode+' Examples',
                    examples=get_example(emode),
                    inputs=get_example_iof(emode)[0],
                    outputs=get_example_iof(emode)[1],
                    fn=get_example_iof(emode)[2],
                    cache_examples=cache_examples)

    # When the mode changes, refresh the instruction text and show only the inputs
    # and outputs relevant to the selected flow.
    mode_input.change(
        fn=lambda x: gr.update(value=get_instruction(x)),
        inputs=mode_input,
        outputs=instruction,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Text-Variation'])),
        inputs=mode_input,
        outputs=img_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Text-to-Image', 'Text-Variation', 'Dual-Guided'])),
        inputs=mode_input,
        outputs=txt_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
        inputs=mode_input,
        outputs=ntxt_input,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Latent-I2T2I'])),
        inputs=mode_input,
        outputs=ptxt_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Text-to-Image', 'Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=coladj_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x == 'Disentanglement')),
        inputs=mode_input,
        outputs=dislvl_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x == 'Dual-Guided')),
        inputs=mode_input,
        outputs=dguide_input,)

    mode_input.change(
        fn=lambda x: gr.update(visible=(x not in ['Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=img_output,)
    mode_input.change(
        fn=lambda x: gr.update(visible=(x in ['Image-to-Text', 'Text-Variation'])),
        inputs=mode_input,
        outputs=txt_output,)

    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
            <h3>
                <b>Caution</b>:
                We would like to raise users' awareness of this demo's potential issues and concerns.
                Like previous large foundation models, Versatile Diffusion can be problematic in some cases, partly due to imperfect training data and pretrained networks (VAEs / context encoders) with limited scope.
                In future research phases, VD may do better on tasks such as text-to-image and image-to-text with the help of more powerful VAEs, more sophisticated network designs, and cleaner data.
                So far, we keep all features available for research testing, both to show the great potential of the VD framework and to collect important feedback for improving the model in the future.
                We welcome researchers and users to report issues through the HuggingFace community discussion feature or by emailing the authors.
            </h3>
            <br>
            <h3>
                <b>Biases and content acknowledgement</b>:
                Beware that VD may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
                VD was trained on the LAION-2B dataset, which scraped non-curated online images and text; although we removed illegal content, unintended exceptions may remain.
                VD in this demo is meant only for research purposes.
            </h3>
        </div>
        """)

demo.launch(debug=True)