Update app.py
Browse files
app.py
CHANGED
@@ -13,30 +13,14 @@ from tqdm.auto import tqdm
|
|
13 |
import random
|
14 |
import gradio as gr
|
15 |
|
16 |
-
|
17 |
-
import torchvision.transforms as transforms
|
18 |
-
|
19 |
-
def tensor_to_pil(tensor):
|
20 |
-
# Assuming tensor is normalized
|
21 |
-
unnormalize = transforms.Normalize(mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5], std=[1/0.5, 1/0.5, 1/0.5])
|
22 |
-
tensor = unnormalize(tensor)
|
23 |
-
tensor = tensor.clamp(0, 1)
|
24 |
-
tensor = tensor.permute(1, 2, 0) # Convert from CxHxW to HxWxC
|
25 |
-
image = Image.fromarray((tensor.numpy() * 255).astype('uint8'))
|
26 |
-
return image
|
27 |
-
|
28 |
-
|
29 |
def generate_images(prompt, guidance_scale, n_samples, num_inference_steps):
|
30 |
seeds = [random.randint(1, 10000) for _ in range(n_samples)]
|
31 |
images = []
|
32 |
for seed in tqdm(seeds):
|
33 |
torch.manual_seed(seed)
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
pil_image = tensor_to_pil(tensor_image)
|
38 |
-
images.append(pil_image)
|
39 |
-
return images
|
40 |
|
41 |
def gr_generate_images(prompt: str, num_images = 1, num_inference = 20, guidance_scale = 8 ):
|
42 |
prompt = prompt + "sks style"
|
|
|
13 |
import random
|
14 |
import gradio as gr
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def generate_images(prompt, guidance_scale, n_samples, num_inference_steps):
|
17 |
seeds = [random.randint(1, 10000) for _ in range(n_samples)]
|
18 |
images = []
|
19 |
for seed in tqdm(seeds):
|
20 |
torch.manual_seed(seed)
|
21 |
+
image = pipe(prompt, num_inference_steps=num_inference_steps,guidance_scale=guidance_scale).images[0]
|
22 |
+
images.append(image)
|
23 |
+
return images[0]
|
|
|
|
|
|
|
24 |
|
25 |
def gr_generate_images(prompt: str, num_images = 1, num_inference = 20, guidance_scale = 8 ):
|
26 |
prompt = prompt + "sks style"
|