yash-srivastava19 committed on
Commit
e9757be
·
1 Parent(s): 6bc7c0f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cohere
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
4
+
5
import os

# The Cohere API key is a secret and must not be committed to source control.
# It is read from the environment instead (for example, a secret configured
# on the hosting platform).
# NOTE(review): the key that was previously hard-coded here should be treated
# as compromised and rotated.
_cohere_api_key = os.environ.get("COHERE_API_KEY")
if not _cohere_api_key:
    raise RuntimeError(
        "COHERE_API_KEY is not set; export it before starting the app."
    )
co_client = cohere.Client(_cohere_api_key)
6
+
7
# All tensors and the model are placed on this device.
device = 'cpu'

# The same checkpoint name is used for the feature extractor, the tokenizer
# and the encoder-decoder model; the three names are kept separate so each
# can be pointed elsewhere independently.
encoder_checkpoint = decoder_checkpoint = model_checkpoint = (
    "nlpconnect/vit-gpt2-image-captioning"
)

# Loaded once at import time and reused by every call to predict().
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
model = model.to(device)
15
+
16
+
17
def predict(image, max_length: int = 64, num_beams: int = 4):
    """Caption an uploaded image and turn the caption into Instagram text.

    The image is captioned with the module-level vision encoder-decoder
    model; the resulting caption is then used as the prompt for
    ``creative_caption`` and ``caption_hashtags``.

    Args:
        image: Object exposing ``convert('RGB')`` (the Gradio input is
            configured with ``type='pil'``), or ``None`` when nothing was
            uploaded.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        A ``(captions, hashtags)`` pair holding whatever
        ``creative_caption`` and ``caption_hashtags`` return, or two empty
        strings when ``image`` is ``None``.
    """
    # The upload widget is declared optional, so this can be called with no
    # image; return empty outputs instead of failing on ``None.convert``.
    if image is None:
        return "", ""

    pixel_values = feature_extractor(
        image.convert('RGB'), return_tensors="pt"
    ).pixel_values.to(device)

    # ``num_beams`` used to be accepted but silently ignored; pass it through
    # so the parameter actually controls decoding.
    caption_ids = model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]

    decoded = tokenizer.decode(caption_ids)
    # Strip the end-of-text marker and keep only the first line of output.
    img_caption_text = decoded.replace('<|endoftext|>', '').split('\n')[0]

    caption_text = creative_caption(img_caption_text)
    hashtags = caption_hashtags(img_caption_text)
    return caption_text, hashtags
26
+
27
+
28
def creative_caption(text):
    """Request Instagram caption ideas for *text* from the Cohere client.

    Returns the result of ``co_client.generate`` for the built prompt,
    unmodified.
    """
    caption_prompt = f"Write some trendy instagram captions for the following prompt - {text}"
    return co_client.generate(prompt=caption_prompt)
30
+
31
+
32
def caption_hashtags(text):
    """Request Instagram hashtag ideas for *text* from the Cohere client.

    Returns the result of ``co_client.generate`` for the built prompt,
    unmodified.
    """
    hashtag_prompt = f"Write some trendy instagram hashtags for the following prompt - {text}"
    return co_client.generate(prompt=hashtag_prompt)
34
+
35
+
36
# Single optional image upload, requested from Gradio with type='pil'.
input_upload = gr.inputs.Image(type='pil', optional=True, label="Upload any Image")

# One text box per value returned by predict(), in the same order.
output = [
    gr.outputs.Textbox(type="auto", label=box_label)
    for box_label in ("Captions", "Hashtags")
]

title = "Image Captioning "
description = "Made for Linesh"

# Wire predict() to the input and output components.
interface = gr.Interface(
    fn=predict,
    inputs=input_upload,
    outputs=output,
    title=title,
    description=description,
    theme="grass",
)

# Start the app as soon as the script is executed.
interface.launch(debug=True)