Spaces:
Runtime error
Runtime error
Commit
·
4d75d06
1
Parent(s):
2f86fc5
Update app.py
Browse filesthings were done
app.py
CHANGED
@@ -14,14 +14,17 @@ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
|
|
14 |
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
|
15 |
|
16 |
|
17 |
-
def predict(image
|
|
|
18 |
image = image.convert('RGB')
|
19 |
image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
|
20 |
clean_text = lambda x: x.replace('<|endoftext|>', '').split('\n')[0]
|
21 |
-
caption_ids = model.generate(image, max_length=
|
22 |
img_caption_text = clean_text(tokenizer.decode(caption_ids))
|
|
|
23 |
caption_text = creative_caption(img_caption_text)
|
24 |
hashtags = caption_hashtags(img_caption_text)
|
|
|
25 |
return caption_text, hashtags
|
26 |
|
27 |
|
@@ -33,13 +36,13 @@ def caption_hashtags(text):
|
|
33 |
return co_client.generate(prompt=f"Write some trendy instagram hashtags for the following prompt - {text}")
|
34 |
|
35 |
|
36 |
-
input_upload = gr.
|
37 |
output = [
|
38 |
-
gr.
|
39 |
-
gr.
|
40 |
]
|
41 |
|
42 |
-
title = "Image Captioning
|
43 |
description = "Made for Linesh"
|
44 |
interface = gr.Interface(
|
45 |
|
|
|
14 |
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
|
15 |
|
16 |
|
17 |
+
def predict(image):
|
18 |
+
"""Predict the generic image caption from the image """
|
19 |
image = image.convert('RGB')
|
20 |
image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
|
21 |
clean_text = lambda x: x.replace('<|endoftext|>', '').split('\n')[0]
|
22 |
+
caption_ids = model.generate(image, max_length=125)[0]
|
23 |
img_caption_text = clean_text(tokenizer.decode(caption_ids))
|
24 |
+
|
25 |
caption_text = creative_caption(img_caption_text)
|
26 |
hashtags = caption_hashtags(img_caption_text)
|
27 |
+
|
28 |
return caption_text, hashtags
|
29 |
|
30 |
|
|
|
36 |
return co_client.generate(prompt=f"Write some trendy instagram hashtags for the following prompt - {text}")
|
37 |
|
38 |
|
39 |
+
input_upload = gr.Image(label="Upload any Image", type='pil', optional=True)
|
40 |
output = [
|
41 |
+
gr.Textbox(type="auto", label="Captions"),
|
42 |
+
gr.Textbox(type="auto", label="Hashtags"),
|
43 |
]
|
44 |
|
45 |
+
title = "Instagram Image Captioning"
|
46 |
description = "Made for Linesh"
|
47 |
interface = gr.Interface(
|
48 |
|