import base64
import requests
import json
import pandas as pd
import os
from tqdm import tqdm
import re
import torch
import io
from PIL import Image
def image_to_bytes(image):
    """Convert a PIL Image to JPEG-encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")  # adjust the format if your images are not JPEG
    return buffer.getvalue()
def query_clip(data, hf_token):
    """Zero-shot image classification via the Hugging Face Inference API."""
    API_URL = "https://api-inference.huggingface.co/models/openai/clip-vit-base-patch32"
    headers = {"Authorization": f"Bearer {hf_token}"}
    img_bytes = image_to_bytes(data["image"])
    encoded_img = base64.b64encode(img_bytes).decode("utf-8")
    payload = {
        "parameters": data["parameters"],
        "inputs": encoded_img,
    }
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
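# A successful response from the zero-shot image classification task is a list of
# label/score dicts sorted by score (scores below are illustrative), e.g.
#   [{"label": "happy", "score": 0.91}, {"label": "angry", "score": 0.09}]
# A quick sanity check might look like (placeholder token and image path):
#   query_clip({"image": Image.open("meme.jpg"),
#               "parameters": {"candidate_labels": ["angry", "happy"]}}, hf_token)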
def get_sentiment(img, hf_token):
    print("Getting the sentiment of the image...")
    output = query_clip({
        "image": img,
        "parameters": {"candidate_labels": ["angry", "happy"]},
    }, hf_token)
    try:
        print("Sentiment:", output[0]["label"])
        return output[0]["label"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("If the model is loading, try again in a minute. If you've reached the query limit (300 per hour), try again within the next hour.")
def query_blip(img, hf_token):
    """Image captioning via the Hugging Face Inference API."""
    API_URL = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
    headers = {"Authorization": f"Bearer {hf_token}"}
    img_bytes = image_to_bytes(img)
    # The image-to-text endpoint expects the raw image bytes as the request body
    # (a form-encoded dict would be rejected), so send the bytes directly.
    response = requests.post(API_URL, headers=headers, data=img_bytes)
    return response.json()
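# A successful response from the image-to-text task is a one-element list with a
# "generated_text" key (the caption below is illustrative):
#   [{"generated_text": "a man sitting on a couch with a laptop"}]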
def get_description(img, hf_token):
    print("Getting the context of the image...")
    output = query_blip(img, hf_token)
    try:
        print("Context:", output[0]["generated_text"])
        return output[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        print(output)
        print("The model is not available right now due to query limits. Try again now or within the next hour.")
def get_model_caption(img, base_model, tokenizer, hf_token, device="cuda"):
    sentiment = get_sentiment(img, hf_token)
    description = get_description(img, hf_token)
    prompt_template = """
    Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
    You are given a topic. Your task is to generate a meme caption based on the topic. Only output the meme caption and nothing more.
    Topic: {query}
    <end_of_turn>\n<start_of_turn>model Caption:
    """
    prompt = prompt_template.format(query=description)
    print("Generating captions...")
    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encodeds.to(device)
    print("Sentiment adapter:", sentiment)
    # Switch to the LoRA adapter fine-tuned for this sentiment ("angry" or "happy").
    base_model.set_adapter(sentiment)
    base_model.to(device)
    generated_ids = base_model.generate(**model_inputs, max_new_tokens=20, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded
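# Usage sketch, not the author's exact setup: it assumes the base model is a
# Gemma-style causal LM (the prompt above uses Gemma turn tokens) wrapped with
# PEFT, with one LoRA adapter per CLIP label so set_adapter(sentiment) can switch
# between them. The model name, adapter repos, and image path are placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    hf_token = os.environ["HF_TOKEN"]  # assumes the token is exported in the environment
    base = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", token=hf_token)  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", token=hf_token)

    # Register one adapter per sentiment under the names CLIP produces.
    model = PeftModel.from_pretrained(base, "your-user/angry-adapter", adapter_name="angry")  # placeholder repo
    model.load_adapter("your-user/happy-adapter", adapter_name="happy")  # placeholder repo

    img = Image.open("meme.jpg")  # placeholder image path
    print(get_model_caption(img, model, tokenizer, hf_token, device="cuda"))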