AYYasaswini committed on
Commit
719784d
·
verified ·
1 Parent(s): a71c93c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -24
app.py CHANGED
@@ -905,30 +905,29 @@ def generate_embed_style(prompt, learned_style, seed):
905
 
906
def generate_image_from_prompt(text_in, style_in):
    """Generate an image for *text_in* in the learned style *style_in*.

    Parameters
    ----------
    text_in : str
        The user's prompt text.
    style_in : str
        A style placeholder token, e.g. '<gartic-phone>' — presumably one of
        the keys of ``dict_styles`` below (TODO confirm against the UI caller).

    Returns
    -------
    list
        ``[generated_image, generate_loss_details]`` as produced by the
        sibling helpers ``generate_embed_style`` / ``generate_loss_style``.
    """
    # Mapping from style token to its learned textual-inversion embedding file.
    dict_styles = {
        '<gartic-phone>': 'learned_embeds_gartic-phone.bin',
        '<hawaiian shirt>': 'learned_embeds_hawaiian-shirt.bin',
        '<gp>': 'learned_embeds_phone01.bin',
        '<style-spdmn>': 'learned_embeds_style-spdmn.bin',
        '<yvmqznrm>': 'learned_embedssd_yvmqznrm.bin',
    }

    # BUG FIX: the previous version ignored both parameters — it always used
    # the hard-coded prompt 'A campfire (oil on canvas)' and unconditionally
    # loaded the first embedding file. Use the caller-supplied values instead,
    # falling back to the first style for unknown tokens (preserves the old
    # never-fails behaviour rather than raising).
    prompt = text_in
    style_seed = 32  # NOTE(review): single fixed seed for all styles — confirm intended
    embed_file = dict_styles.get(style_in, 'learned_embeds_gartic-phone.bin')

    # torch.load on a trusted local checkpoint shipped with the app.
    style_embed = torch.load(embed_file)

    generated_image = generate_embed_style(prompt, style_embed, style_seed)
    generate_loss_details = generate_loss_style(prompt, style_embed, style_seed)

    return [generated_image, generate_loss_details]
932
 
933
 
934
  # Define Interface
 
905
 
906
def generate_image_from_prompt(text_in, style_in):
    """Generate an image for *text_in* rendered in the learned style *style_in*.

    Parameters
    ----------
    text_in : str
        The user's prompt text.
    style_in : str
        Base name of a learned-embedding file, e.g.
        'learned_embeds_gartic-phone' — the '_style.bin' suffix is appended
        here to form the filename.

    Returns
    -------
    list
        A one-element list ``[generated_image]``.

    Raises
    ------
    ValueError
        If ``style_in`` does not correspond to a known style file.
    """
    # STYLE_SEEDS is parallel to STYLE_LIST: STYLE_SEEDS[i] is the fixed seed
    # used with STYLE_LIST[i].
    STYLE_LIST = [
        'learned_embeds_gartic-phone_style.bin',
        'learned_embeds_hawaiian-shirt_style.bin',
        'learned_embeds_phone01_style.bin',
        'learned_embeds_style-spdmn_style.bin',
        'learned_embedssd_yvmqznrm_style.bin',
    ]
    STYLE_SEEDS = [128, 64, 128, 64, 128]

    style_file = style_in + '_style.bin'
    idx = STYLE_LIST.index(style_file)  # raises ValueError for unknown styles

    prompt = text_in
    style_seed = STYLE_SEEDS[idx]

    # Load the textual-inversion embedding; take the first (and presumably
    # only) tensor stored in the checkpoint dict — TODO confirm file layout.
    style_dict = torch.load(style_file)
    style_embed = list(style_dict.values())

    generated_image = embed_style(prompt, style_embed[0], style_seed)

    # BUG FIX: the previous version also ran
    #   loss_generated_img = loss_style(prompt, style_embed[0], style_seed)
    # and then discarded the result — an expensive dead computation, removed
    # here. NOTE(review): restore the call if loss_style has side effects
    # (e.g. saving files) that the app relies on — cannot be confirmed from
    # this function alone. Debug print() calls were also removed.
    return [generated_image]
 
931
 
932
 
933
  # Define Interface