sophiaaez committed
Commit b323982
1 Parent(s): 8c63a0e

Update app.py

Files changed (1)
  1. app.py +41 -20
app.py CHANGED
@@ -43,32 +43,53 @@ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
 model_vq.eval()
 model_vq = model_vq.to(device)
 
-
-
-def inference(raw_image, model_n, question, strategy):
+def getModelPath(language):
+    if language == 'English':
+        path = None
+    elif language == 'German':
+        path = "Helsinki-NLP/opus-mt-en-de"
+    elif language == 'French':
+        path = "Helsinki-NLP/opus-mt-en-fr"
+    elif language == 'Spanish':
+        path = "Helsinki-NLP/opus-mt-en-es"
+    elif language == 'Chinese':
+        path = "Helsinki-NLP/opus-mt-en-zh"
+    elif language == 'Ukranian':
+        path = "Helsinki-NLP/opus-mt-en-uk"
+    elif language == 'Swedish':
+        path = "Helsinki-NLP/opus-mt-en-sv"
+    elif language == 'Arabic':
+        path = "Helsinki-NLP/opus-mt-en-ar"
+    elif language == 'Italian':
+        path = "Helsinki-NLP/opus-mt-en-it"
+    elif language == 'Hindi':
+        path = "Helsinki-NLP/opus-mt-en-hi"
+    return(path)
+
+def inference(input_img,strategy,language):
     if model_n == 'Image Captioning':
         image = transform(raw_image).unsqueeze(0).to(device)
         with torch.no_grad():
-            if strategy == "Beam search":
-                caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
-            else:
-                caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
-            return 'caption: '+caption[0]
-
-    else:
-        image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
-        with torch.no_grad():
-            answer = model_vq(image_vq, question, train=False, inference='generate')
-            return 'answer: '+answer[0]
+            if strategy == "Beam search":
+                cap = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+            else:
+                cap = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
+            if modelpath:
+                translator = pipeline("translation", model=modelpath)
+                trans_cap = translator(cap[0])
+                tc = trans_cap[0]['translation_text']
+                return str(tc)
+            else:
+                return str(cap[0])
 
-inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning',"Visual Question Answering"], type="value", default="Image Captioning", label="Task"),gr.inputs.Textbox(lines=2, label="Question"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Nucleus sampling", label="Caption Decoding Strategy")]
-outputs = gr.outputs.Textbox(label="Output")
 
-title = "BLIP"
+description = "A pipeline of BLIP image captioning and Helsinki translation in order to generate image captions in a language of your choice either with beam search (deterministic) or nucleus sampling (stochastic). Enjoy! Is the language you want to use missing? Let me know and I'll integrate it."
 
-description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation (Salesforce Research). To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
 
-article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"
+inputs_ = [gr.inputs.Image(type='filepath', label="Input Image"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Nucleus sampling", label="Mode"), gr.inputs.Radio(choices=['English','German', 'French', 'Spanish', 'Chinese', 'Ukranian', 'Swedish', 'Arabic', 'Italian', 'Hindi'],type="value", default = 'German',label="Language")]
 
-
-gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starrynight.jpeg',"Image Captioning","None","Nucleus sampling"]]).launch(enable_queue=True)
+outputs_ = gr.outputs.Textbox(label="Output")
+
+iface = gr.Interface(inference, inputs_, outputs_, description=description)
+
+iface.launch(debug=True,show_error=True)
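
The new flow is caption first, translate second. As a minimal sketch (not the committed code), this is how the two halves would chain together, assuming modelpath is looked up via getModelPath(language) and the filepath input is opened with PIL; model, transform and device are the BLIP objects defined earlier in app.py, and caption_and_translate is a hypothetical name used only for illustration:

    # Sketch under assumptions: model, transform, device and getModelPath
    # are the globals defined earlier in app.py; the hunk above still
    # references the old model_n / raw_image names, so this wiring is inferred.
    import torch
    from PIL import Image
    from transformers import pipeline

    def caption_and_translate(input_img, strategy, language):
        raw_image = Image.open(input_img).convert('RGB')        # gr.inputs.Image(type='filepath') passes a path
        image = transform(raw_image).unsqueeze(0).to(device)    # same BLIP preprocessing as in the diff
        with torch.no_grad():
            if strategy == "Beam search":
                cap = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
            else:
                cap = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
        modelpath = getModelPath(language)                      # None for English, else an opus-mt checkpoint
        if modelpath:
            translator = pipeline("translation", model=modelpath)
            return translator(cap[0])[0]['translation_text']    # translated caption
        return cap[0]                                           # English caption unchanged

For the translation half on its own, pipeline("translation", model="Helsinki-NLP/opus-mt-en-de") returns a list of dicts whose 'translation_text' field holds the German caption, which is what the hunk above reads out as tc.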