nightfury commited on
Commit
f2df246
β€’
1 Parent(s): d25b5a8

Create appli.py

Browse files
Files changed (1) hide show
  1. appli.py +57 -0
appli.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+
4
+ import torch
5
+ import gradio as gr
6
+ from clip_interrogator import Config, Interrogator
7
+
8
+
9
+ CACHE_URLS = [
10
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_artists.pkl',
11
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_flavors.pkl',
12
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_mediums.pkl',
13
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_movements.pkl',
14
+ 'https://huggingface.co/pharma/ci-preprocess/resolve/main/ViT-H-14_laion2b_s32b_b79k_trendings.pkl',
15
+ ]
16
+ os.makedirs('cache', exist_ok=True)
17
+ for url in CACHE_URLS:
18
+ subprocess.run(['wget', url, '-P', 'cache'], stdout=subprocess.PIPE).stdout.decode('utf-8')
19
+
20
+
21
+ config = Config()
22
+ config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
+ config.blip_offload = False if torch.cuda.is_available() else True
24
+ config.chunk_size = 2048
25
+ config.flavor_intermediate_count = 512
26
+ config.blip_num_beams = 64
27
+ ci = Interrogator(config)
28
+
29
+
30
+ def inference(image, mode, best_max_flavors):
31
+ image = image.convert('RGB')
32
+ if mode == 'best':
33
+ prompt_result = ci.interrogate(image, max_flavors=int(best_max_flavors))
34
+ elif mode == 'classic':
35
+ prompt_result = ci.interrogate_classic(image)
36
+ else:
37
+ prompt_result = ci.interrogate_fast(image)
38
+ return prompt_result
39
+
40
+
41
+ with gr.Blocks() as demo:
42
+ with gr.Column():
43
+ gr.Markdown("# CLIP Interrogator")
44
+ input_image = gr.Image(type='pil', elem_id="input-img")
45
+ with gr.Row():
46
+ mode_input = gr.Radio(['best', 'classic', 'fast'], label='Select mode', value='best')
47
+ flavor_input = gr.Slider(minimum=2, maximum=48, step=2, value=32, label='best mode max flavors')
48
+ submit_btn = gr.Button("Submit")
49
+ output_text = gr.Textbox(label="Description Output")
50
+ submit_btn.click(
51
+ fn=inference,
52
+ inputs=[input_image, mode_input, flavor_input],
53
+ outputs=[output_text],
54
+ concurrency_limit=10
55
+ )
56
+
57
+ demo.queue().launch()