Update app.py
Browse files
app.py
CHANGED
@@ -13,13 +13,29 @@ from tqdm.auto import tqdm
|
|
13 |
import random
|
14 |
import gradio as gr
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def generate_images(prompt, guidance_scale, n_samples, num_inference_steps):
|
17 |
seeds = [random.randint(1, 10000) for _ in range(n_samples)]
|
18 |
images = []
|
19 |
for seed in tqdm(seeds):
|
20 |
torch.manual_seed(seed)
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
return images
|
24 |
|
25 |
def gr_generate_images(prompt: str, num_images = 1, num_inference = 20, guidance_scale = 8 ):
|
|
|
13 |
import random
|
14 |
import gradio as gr
|
15 |
|
16 |
+
|
17 |
+
import torchvision.transforms as transforms
|
18 |
+
|
19 |
+
def tensor_to_pil(tensor):
|
20 |
+
# Assuming tensor is normalized
|
21 |
+
unnormalize = transforms.Normalize(mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5], std=[1/0.5, 1/0.5, 1/0.5])
|
22 |
+
tensor = unnormalize(tensor)
|
23 |
+
tensor = tensor.clamp(0, 1)
|
24 |
+
tensor = tensor.permute(1, 2, 0) # Convert from CxHxW to HxWxC
|
25 |
+
image = Image.fromarray((tensor.numpy() * 255).astype('uint8'))
|
26 |
+
return image
|
27 |
+
|
28 |
+
|
29 |
def generate_images(prompt, guidance_scale, n_samples, num_inference_steps):
|
30 |
seeds = [random.randint(1, 10000) for _ in range(n_samples)]
|
31 |
images = []
|
32 |
for seed in tqdm(seeds):
|
33 |
torch.manual_seed(seed)
|
34 |
+
#tensor = pipe(prompt, num_inference_steps=num_inference_steps,guidance_scale=guidance_scale).images[0]
|
35 |
+
tensor_image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
|
36 |
+
# Convert tensor to PIL Image
|
37 |
+
pil_image = tensor_to_pil(tensor_image)
|
38 |
+
images.append(pil_image)
|
39 |
return images
|
40 |
|
41 |
def gr_generate_images(prompt: str, num_images = 1, num_inference = 20, guidance_scale = 8 ):
|