sneha
commited on
Commit
·
30f2229
1
Parent(s):
cf37148
fix repo_name, add css
Browse files
app.py
CHANGED
@@ -57,9 +57,10 @@ def download_bin(model):
|
|
57 |
else:
|
58 |
raise NameError("model not found: " + model)
|
59 |
|
|
|
60 |
bin_path = os.path.join(MODEL_DIR,bin_file)
|
61 |
if not os.path.isfile(bin_path):
|
62 |
-
model_bin = hf_hub_download(repo_id=
|
63 |
os.rename(model_bin, bin_path)
|
64 |
|
65 |
|
@@ -95,11 +96,12 @@ input_img = gr.Image(shape=(250,250))
|
|
95 |
input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
96 |
output_img = gr.Image(shape=(250,250))
|
97 |
output_plot = gr.Plot()
|
|
|
98 |
|
99 |
markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
|
100 |
The user can decide how the attention heads will be combined. \
|
101 |
Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
|
102 |
demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
|
103 |
examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
|
104 |
-
inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot])
|
105 |
demo.launch()
|
|
|
57 |
else:
|
58 |
raise NameError("model not found: " + model)
|
59 |
|
60 |
+
repo_name = 'facebook/' + model
|
61 |
bin_path = os.path.join(MODEL_DIR,bin_file)
|
62 |
if not os.path.isfile(bin_path):
|
63 |
+
model_bin = hf_hub_download(repo_id=repo_name, filename='pytorch_model.bin',local_dir=MODEL_DIR,local_dir_use_symlinks=True,token=HF_TOKEN)
|
64 |
os.rename(model_bin, bin_path)
|
65 |
|
66 |
|
|
|
96 |
input_button = gr.Radio(["min", "max", "mean"], value="min",label="Attention Head Fusion", info="How to combine the last layer attention across all 12 heads of the transformer.")
|
97 |
output_img = gr.Image(shape=(250,250))
|
98 |
output_plot = gr.Plot()
|
99 |
+
css = ".output-image, .input-image, .image-preview {height: 600px !important}"
|
100 |
|
101 |
markdown ="This is a demo for the Visual Cortex models. When passed an image input, it displays the attention of the last layer of the transformer.\n \
|
102 |
The user can decide how the attention heads will be combined. \
|
103 |
Along with the attention heatmap, it also displays the embedding values reshaped to a 16x48 or 16x64 grid."
|
104 |
demo = gr.Interface(fn=run_attn, title="Visual Cortex Large Model", description=markdown,
|
105 |
examples=[[os.path.join('./imgs',x),None,None]for x in os.listdir(os.path.join(os.getcwd(),'imgs')) if 'jpg' in x],
|
106 |
+
inputs=[input_img,model_type,input_button],outputs=[output_img,output_plot],css=css)
|
107 |
demo.launch()
|