hysts HF staff commited on
Commit
d7dfab0
1 Parent(s): 2a6a910

Use only Flan-T5-XXL

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -16,14 +16,14 @@ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16
  MODEL_ID_OPT_6_7B = 'Salesforce/blip2-opt-6.7b'
17
  MODEL_ID_FLAN_T5_XXL = 'Salesforce/blip2-flan-t5-xxl'
18
  model_dict = {
19
- MODEL_ID_OPT_6_7B: {
20
- 'processor':
21
- AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
22
- 'model':
23
- Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
24
- device_map='auto',
25
- load_in_8bit=True),
26
- },
27
  MODEL_ID_FLAN_T5_XXL: {
28
  'processor':
29
  AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
@@ -148,11 +148,13 @@ with gr.Blocks(css='style.css') as demo:
148
  model_id_caption = gr.Dropdown(
149
  label='Model ID for image captioning',
150
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
151
- value=MODEL_ID_OPT_6_7B)
 
152
  model_id_chat = gr.Dropdown(
153
  label='Model ID for VQA',
154
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
155
- value=MODEL_ID_FLAN_T5_XXL)
 
156
  sampling_method = gr.Radio(
157
  label='Text Decoding Method',
158
  choices=['Beam search', 'Nucleus sampling'],
 
16
  MODEL_ID_OPT_6_7B = 'Salesforce/blip2-opt-6.7b'
17
  MODEL_ID_FLAN_T5_XXL = 'Salesforce/blip2-flan-t5-xxl'
18
  model_dict = {
19
+ #MODEL_ID_OPT_6_7B: {
20
+ # 'processor':
21
+ # AutoProcessor.from_pretrained(MODEL_ID_OPT_6_7B),
22
+ # 'model':
23
+ # Blip2ForConditionalGeneration.from_pretrained(MODEL_ID_OPT_6_7B,
24
+ # device_map='auto',
25
+ # load_in_8bit=True),
26
+ #},
27
  MODEL_ID_FLAN_T5_XXL: {
28
  'processor':
29
  AutoProcessor.from_pretrained(MODEL_ID_FLAN_T5_XXL),
 
148
  model_id_caption = gr.Dropdown(
149
  label='Model ID for image captioning',
150
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
151
+ value=MODEL_ID_FLAN_T5_XXL,
152
+ interactive=False)
153
  model_id_chat = gr.Dropdown(
154
  label='Model ID for VQA',
155
  choices=[MODEL_ID_OPT_6_7B, MODEL_ID_FLAN_T5_XXL],
156
+ value=MODEL_ID_FLAN_T5_XXL,
157
+ interactive=False)
158
  sampling_method = gr.Radio(
159
  label='Text Decoding Method',
160
  choices=['Beam search', 'Nucleus sampling'],