|
import torch |
|
from PIL import Image |
|
|
|
import random |
|
import pandas as pd |
|
import gradio as gr |
|
import numpy as np |
|
from sklearn.linear_model import LogisticRegression |
|
from sklearn.svm import SVC |
|
from sklearn import preprocessing |
|
import time |
|
import torch |
|
from matplotlib import pyplot as plt |
|
|
|
from model import model, tokenizer, load_image |
|
|
|
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler |
|
from huggingface_hub import hf_hub_download |
|
from safetensors.torch import load_file |
|
|
|
# Execution target and numeric precision shared by every model in this script.
device = 'cuda'
dtype = torch.bfloat16

# Hugging Face identifiers: the SDXL base pipeline, and the repository/file
# that provide the replacement UNet weights.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_8step_unet.safetensors"

# Build the UNet from the base model's config, then overwrite its parameters
# with the downloaded checkpoint.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
_ckpt_path = hf_hub_download(repo, ckpt)
_ckpt_state = load_file(_ckpt_path, device=device)
unet.load_state_dict(_ckpt_state)

# Assemble the text-to-image pipeline around that UNet and move it to the GPU.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    torch_dtype=dtype,
    variant="fp16",
)
pipe = pipe.to(device)

# Swap in an Euler scheduler that reuses the pipeline's scheduler config but
# with "trailing" timestep spacing.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
|
|
|
|
|
|
|
|
|
|
|
# Embed the reference image 'blank.png' once at start-up and keep the result
# as a detached float32 tensor.
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
    pixel_values = load_image(image_file='blank.png', max_num=1)
    pixel_values = pixel_values.to(device)
    _blank_features = model.extract_feature(pixel_values.to(dtype))
    base_embed = _blank_features.detach().float()
|
|
|
|
|
|
|
def get_text(embed):
    """Generate text from the chat model conditioned on *embed*.

    The model is called with a placeholder ``0`` in place of pixel values and
    an empty question; *embed* (cast to the module-level ``dtype``) is passed
    as ``visual_features``. Up to 32 new tokens are sampled with temperature
    0.5 and top-p 0.92.

    Args:
        embed: Feature tensor forwarded to ``model.chat`` as
            ``visual_features``.

    Returns:
        Whatever ``model.chat`` returns; it is also printed to stdout.
    """
    sampling_options = {
        'max_new_tokens': 32,
        'do_sample': True,
        'temperature': .5,
        'top_p': .92,
    }
    no_pixels = 0
    empty_question = ''

    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        response = model.chat(
            tokenizer,
            no_pixels,
            empty_question,
            sampling_options,
            visual_features=embed.to(dtype),
        )

    print(response)
    return response
|
|
|
def get_image(text):
    """Render *text* with the module-level pipeline and return the first image.

    The pipeline is run for 8 inference steps with ``guidance_scale=0``.
    """
    result = pipe(text, num_inference_steps=8, guidance_scale=0)
    first_image = result.images[0]
    return first_image
|
|
|
def get_embed(img):
    """Return the vision model's features for *img* as a float32 tensor.

    Args:
        img: Image handed to ``load_image`` through its ``pil_image``
            argument (presumably a PIL image -- confirm against
            ``model.load_image``).

    Returns:
        The output of ``model.extract_feature`` converted with ``.float()``.
    """
    with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
        pixels = load_image(image_file='', pil_image=img, max_num=1)
        pixels = pixels.to(device)
        features = model.extract_feature(pixels.to(dtype))
    return features.float()
|
|
|
|
|
|
|
# Unique prompt strings taken from the second column of the CSV, in random
# order. Non-string cells (such as the NaN pandas produces for missing
# values) are filtered out.
prompt_list = [
    p
    for p in set(
        pd.read_csv('/home/ryn_mote/Misc/twitter_prompts.csv').iloc[:, 1].tolist()
    )
    if isinstance(p, str)
]
random.shuffle(prompt_list)
|
|
|
|
|
|
|
# Alternative prompt set; nothing in the visible part of this file reads it.
NOT_calibrate_prompts = [
    "an abstract painting",
    (
        "unique streetwear design that blends the old with the new. "
        "Combine bold, urban typography with retro graphics, taking "
        "inspiration from distressed signage and graffiti. "
        "Use a range of earthy tones to give the design a vintage aesthetic, "
        "while adding a modern twist with a stylistic rendering of the graphics"
    ),
    "a photo of hell",
    "",
]
|
|
|
# Prompts shown during the calibration phase. next_image() pops them from the
# front one at a time and switches to roaming once the list is empty.
calibrate_prompts = [
    "4k photo",
    "surrealist art",
    "a psychedelic, fractal view",
    "a beautiful collage",
    "an intricate portrait",
    "an impressionist painting",
    "abstract art",
    "an eldritch image",
    "a sketch",
    "a city full of darkness and graffiti",
    "a black & white photo",
    "a brilliant, timeless tarot card of the world",
    "eternity: a timeless, vivid painted portrait by ryan murdock",
    "a simple, timeless, & dark charcoal on canvas: death itself by ryan murdock",
    "a painted image with gorgeous red gradients: Persephone by ryan murdock",
    "a simple, timeless, & dark photo with gorgeous gradients: last night of my life by ryan murdock",
    "the sunflower -- a dark, simple painted still life by ryan murdock",
    "silence in the macrocosm -- a dark, intricate painting by ryan murdock",
    "beauty here -- a photograph by ryan murdock",
    "a timeless, haunting portrait: the necrotic jester",
    "a simple, timeless, & dark art piece with gorgeous gradients: serenity",
    "an elegant image of nature with gorgeous swirling gradients",
    "simple, timeless digital art with gorgeous purple spirals",
    "timeless digital art with gorgeous gradients: eternal slumber",
    "a simple, timeless image with gorgeous gradients",
    "a simple, timeless painted image of nature with beautiful gradients",
    "a timeless, dark digital art piece with gorgeous gradients: the hanged man",
    "",
]
|
|
|
|
|
|
|
# Mutable session state shared by the Gradio callbacks below.
global_idx = 0      # bumped by choose() on every rating click
embs, ys = [], []   # embeddings of shown images / their ratings (1 like, 0 dislike)

# Wall-clock time at which the module finished its set-up.
start_time = time.time()
|
|
|
def _save_histogram(tensor, path):
    """Write a histogram of every value in *tensor* to the image file *path*."""
    plt.close()
    plt.hist(tensor.detach().cpu().float().flatten())
    plt.savefig(path)


def next_image():
    """Produce the next ``(image, caption)`` pair to show the user.

    While ``calibrate_prompts`` still has entries, the first one is popped,
    rendered with ``get_image`` and returned together with its prompt.

    Once the list is exhausted ("roaming"), a new feature vector is built from
    the ratings collected so far: one randomly chosen liked feature is shifted
    along the direction (mean of liked - mean of disliked), rescaled so the
    direction has the same standard deviation as the chosen feature. That
    vector is turned into text with ``get_text`` and the text into an image
    with ``get_image``.

    In both modes the embedding of the returned image is appended to the
    module-level ``embs`` list; the matching rating is appended to ``ys``
    later by ``choose``.

    Returns:
        tuple: ``(image, text)`` where ``text`` is the prompt used to render
        ``image``.
    """
    with torch.no_grad():
        if calibrate_prompts:
            prompt = calibrate_prompts.pop(0)
            print(f'######### Calibrating with sample: {prompt} #########')

            image = get_image(prompt)
            # get_embed opens its own autocast context with the module dtype,
            # so no additional autocast wrapper is needed here (this matches
            # the roaming branch below).
            embs.append(get_embed(image))
            return image, prompt

        print('######### Roaming #########')

        n_liked = sum(1 for y in ys if y > .5)
        n_disliked = len(ys) - n_liked

        if min(n_liked, n_disliked) < 1:
            # Without at least one like and one dislike there is no direction
            # to follow, so fall back to two random vectors labelled 0 and 1.
            # NOTE(review): the width 1280 is hard-coded -- confirm it matches
            # the last dimension returned by model.extract_feature.
            feature_embs = torch.stack([torch.randn(1280), torch.randn(1280)])
            ys_t = [0, 1]
            print('Not enough ratings.')
        else:
            ys_t = list(ys)
            # Only embeddings that already have a rating are used; element
            # [0, 0] of each stored embedding serves as its feature vector.
            feature_embs = torch.stack(
                [embs[i][0, 0].detach().cpu() for i in range(len(ys))]
            ).squeeze()

        print(tuple(feature_embs.shape), (len(ys_t),))

        # Split the feature rows once by rating.
        liked = [row for row, y in zip(feature_embs, ys_t) if y > .5]
        disliked = [row for row, y in zip(feature_embs, ys_t) if y < .5]

        pos_sol = torch.stack(liked).mean(0, keepdim=True).to(device, dtype)
        neg_sol = torch.stack(disliked).mean(0, keepdim=True).to(device, dtype)

        # Start from one randomly chosen liked sample.
        anchor = random.choice(liked).to(device, dtype)

        dif = pos_sol - neg_sol
        sol = anchor + ((dif / dif.std()) * anchor.std())

        print(sol.shape)

        text = get_text(sol)
        image = get_image(text)
        embed = get_embed(image)

        embs.append(embed)

        # Debug output: value distributions of the target vector and of the
        # embedding actually obtained from the generated image.
        _save_histogram(sol, 'sol.jpg')
        _save_histogram(embed, 'embed.jpg')

        return image, text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start(_):
    """Handle the Start button: unlock the rating buttons and show an image.

    Args:
        _: Value of the clicked button; ignored.

    Returns:
        list: Enabled 'Like', 'Neither' and 'Dislike' buttons, a disabled
        'Start' button, then the image and text returned by ``next_image``.
    """
    outputs = [
        gr.Button(value=label, interactive=True)
        for label in ('Like', 'Neither', 'Dislike')
    ]
    outputs.append(gr.Button(value='Start', interactive=False))

    image, caption = next_image()
    outputs.extend((image, caption))
    return outputs
|
|
|
|
|
def choose(choice):
    """Record the user's rating of the current image and fetch the next one.

    Args:
        choice: Label of the clicked button. 'Like' is stored as 1,
            'Neither' discards the current image's embedding without storing
            a rating, and any other value is stored as 0.

    Returns:
        The ``(image, text)`` pair returned by ``next_image``.
    """
    global global_idx
    global_idx += 1

    if choice == 'Neither':
        # No rating is kept for this image, so remove its embedding to keep
        # embs and ys aligned.
        embs.pop()
        return next_image()

    ys.append(1 if choice == 'Like' else 0)
    return next_image()
|
|
|
# Fix the size of the image area and centre it.
css = "div#output-image {height: 512px !important; width: 512px !important; margin:auto;}"

with gr.Blocks(css=css) as demo:
    with gr.Row():
        # Banner text. The closing tag and the font-size unit are written out
        # explicitly so the browser parses the markup and applies the style.
        html = gr.HTML(
            "<div style='text-align:center; font-size:32px'>"
            "You will calibrate for several prompts and then roam.</div>"
        )
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image')
    with gr.Row(elem_id='output-txt'):
        txt = gr.Textbox(interactive=False, elem_id='output-txt')
    with gr.Row(equal_height=True):
        # Rating buttons start disabled; start() enables them.
        b3 = gr.Button(value='Dislike', interactive=False)
        b2 = gr.Button(value='Neither', interactive=False)
        b1 = gr.Button(value='Like', interactive=False)
        # Each rating button passes its own label to choose() and refreshes
        # the image and caption with the result.
        for rating_button in (b1, b2, b3):
            rating_button.click(choose, [rating_button], [img, txt])
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start, [b4], [b1, b2, b3, b4, img, txt])

# share=True asks Gradio for a public share link in addition to the local server.
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|