sneha committed on
Commit
30f2229
·
1 Parent(s): cf37148

fix repo_name, add css

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -57,9 +57,10 @@ def download_bin(model):
57
  else:
58
  raise NameError("model not found: " + model)
59
 
 
60
  bin_path = os.path.join(MODEL_DIR,bin_file)
61
  if not os.path.isfile(bin_path):
62
- model_bin = hf_hub_download(repo_id=REPO_ID, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
63
  os.rename(model_bin, bin_path)
64
 
65
 
@@ -95,11 +96,12 @@ input_img = gr.Image(shape=(250,250))
95
  input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
96
  output_img = gr.Image(shape=(250,250))
97
  output_plot = gr.Plot()
 
98
 
99
  markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
100
  The user can decide how the attention heads will be combined. \
101
  Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
102
  demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
103
  examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
104
- inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot])
105
  demo.launch()
 
57
  else:
58
  raise NameError("model not found: " + model)
59
 
60
+ repo_name = 'facebook/' + model
61
  bin_path = os.path.join(MODEL_DIR,bin_file)
62
  if not os.path.isfile(bin_path):
63
+ model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
64
  os.rename(model_bin, bin_path)
65
 
66
 
 
96
  input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
97
  output_img = gr.Image(shape=(250,250))
98
  output_plot = gr.Plot()
99
+ css = ".output-image, .input-image, .image-preview {height: 600px !important}"
100
 
101
  markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
102
  The user can decide how the attention heads will be combined. \
103
  Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
104
  demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
105
  examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
106
+ inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot],css=css)
107
  demo.launch()