Spaces:
Runtime error
Runtime error
Om-Alve
commited on
Commit
·
319f6be
1
Parent(s):
a459d13
downscaling
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ class StyleTransfer(nn.Module):
|
|
35 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
36 |
|
37 |
def image_merger(content, style,beta=10,device=device):
|
38 |
-
size =
|
39 |
alpha = 1
|
40 |
beta *= 1000
|
41 |
content = Image.fromarray(content)
|
@@ -52,7 +52,7 @@ def image_merger(content, style,beta=10,device=device):
|
|
52 |
generator = StyleTransfer().to(device).eval()
|
53 |
opt = torch.optim.Adam([generated],lr=0.06)
|
54 |
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.9) # Learning rate scheduler
|
55 |
-
num_epochs = 30 if device
|
56 |
style_features,_ = generator(style)
|
57 |
_,content_features = generator(content)
|
58 |
loop = tqdm(range(num_epochs),leave=False)
|
@@ -74,7 +74,7 @@ def image_merger(content, style,beta=10,device=device):
|
|
74 |
total_loss.backward(retain_graph=True)
|
75 |
opt.step()
|
76 |
scheduler.step()
|
77 |
-
if total_loss < 200 and device
|
78 |
break
|
79 |
print(total_loss.item())
|
80 |
img = np.array(generated.cpu().detach().squeeze(0).permute(1,2,0))
|
|
|
35 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
36 |
|
37 |
def image_merger(content, style,beta=10,device=device):
|
38 |
+
size = 300
|
39 |
alpha = 1
|
40 |
beta *= 1000
|
41 |
content = Image.fromarray(content)
|
|
|
52 |
generator = StyleTransfer().to(device).eval()
|
53 |
opt = torch.optim.Adam([generated],lr=0.06)
|
54 |
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.9) # Learning rate scheduler
|
55 |
+
num_epochs = 30 if device != "cuda" else 100
|
56 |
style_features,_ = generator(style)
|
57 |
_,content_features = generator(content)
|
58 |
loop = tqdm(range(num_epochs),leave=False)
|
|
|
74 |
total_loss.backward(retain_graph=True)
|
75 |
opt.step()
|
76 |
scheduler.step()
|
77 |
+
if total_loss < 200 and device!='cuda':
|
78 |
break
|
79 |
print(total_loss.item())
|
80 |
img = np.array(generated.cpu().detach().squeeze(0).permute(1,2,0))
|