Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
"""get_caption.py: """ | |
__author__ = "Rishabh Gupta" | |
__copyright__ = "Copyright 2023, Zuma" | |
__created_date__ = "04-08-2023" | |
import os | |
import json | |
import requests | |
import base64 | |
from io import BytesIO | |
import PIL | |
def get_image_caption(img_obj, prompt=None, timeout=60):
    """Request a caption for an image from the remote captioning endpoint.

    The image is serialised to JPEG, base64-encoded and POSTed as JSON to the
    URL held in the ``IMAGE_CAPTION_ENDPOINT`` environment variable. The
    value of ``IMAGE_CAPTION_AUTH_TOKEN`` is sent as the ``Authorization``
    header when it is set.

    Args:
        img_obj: Image object exposing a PIL-style ``save(fp, format=...)``
            method (e.g. a ``PIL.Image.Image``). ``None`` is tolerated.
        prompt: Accepted for interface compatibility but currently ignored;
            the payload always sends ``"prompt": None``.
        timeout: Seconds to wait for the endpoint before giving up.

    Returns:
        The value of the ``captions`` field of the endpoint's JSON response
        on success, otherwise a human-readable error message string. This
        function reports failures through its return value rather than by
        raising.
    """
    if img_obj is None:
        return "No image data provided!"

    buffered = BytesIO()
    try:
        # Pillow's JPEG writer rejects images with an alpha channel or a
        # palette, so flatten those modes to RGB before encoding.
        if getattr(img_obj, "mode", None) in ("RGBA", "LA", "P"):
            img_obj = img_obj.convert("RGB")
        img_obj.save(buffered, format="JPEG")
    except PIL.UnidentifiedImageError as e:
        print("Error in saving the image", e)
        return "Invalid image format!"
    except Exception as e:
        # Boundary handler: any failure to encode the image (including a
        # non-image argument) is reported to the caller as a message.
        print("Some error occurred while saving the image", e)
        return "Some error occurred while loading the image!"

    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    url = os.environ.get("IMAGE_CAPTION_ENDPOINT")
    auth_token = os.environ.get("IMAGE_CAPTION_AUTH_TOKEN")
    if not url:
        # Without an endpoint there is nothing to call; fail with a message
        # instead of letting the HTTP client raise on a missing URL.
        print("IMAGE_CAPTION_ENDPOINT is not configured")
        return "Error in generating the caption!"

    payload = json.dumps({
        "inputs": img_str,
        "prompt": None  # Prompts are disabled as accuracy is not very good or not experimented enough.
    })
    headers = {'Content-Type': 'application/json'}
    if auth_token:
        headers['Authorization'] = auth_token

    try:
        response = requests.request(
            "POST", url, headers=headers, data=payload, timeout=timeout
        )
    except requests.RequestException as e:
        # Covers connection failures, timeouts and malformed URLs.
        print("Error while calling the caption endpoint", e)
        return "Error in generating the caption!"

    print("Response Status Code = ", response.status_code)
    print("Response = ", response.text)

    if response.status_code == 200:
        try:
            caption = json.loads(response.text)["captions"]
        except (ValueError, KeyError, TypeError) as e:
            # Body was not JSON, or did not have the expected "captions" key.
            print("Unexpected response body from the caption endpoint", e)
            return "Error in generating the caption!"
        print(f"Caption = {caption}")
        return caption

    if response.status_code == 502:
        print("Status 502")
        return "Model is getting loaded, please wait for 3-4 minutes and then try again."

    return "Error in generating the caption!"
if __name__ == '__main__':
    # Manual smoke test: caption a local image file and print the result.
    # get_image_caption() calls .save() on its argument, so it must be given
    # an opened image object rather than a base64-encoded string.
    import sys

    from PIL import Image

    # An optional first command-line argument overrides the default sample path.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/Users/rishabh/Downloads/5dbadb934042f1.16205085439.jpg"
    )
    with Image.open(image_path) as image:
        print(get_image_caption(image))