# Spaces:
# Runtime error
# Runtime error
import json
from pprint import pprint
from typing import Optional

import pandas as pd
from openai import OpenAI

from secrets_key import OPENAI_KEY, RANDOM_SEED

# Module-level OpenAI client, authenticated with the key imported from
# the local secrets_key module.
client = OpenAI(api_key=OPENAI_KEY)
# Instruction template for the vision request. The single `{story}`
# placeholder is filled in with str.format(); the text contains no other
# braces, so no escaping is needed.
prompt = """
You are given a story and 3 images related to the story. Identify a person/object that can be visually identified in the images but not directly mentioned on the story. Use as few words as possible to describe each person/object. Also, mention the image number (1, 2 or 3) where the person/object can be found.
Output in a python list of dictionaries. Each dictionary should have the following keys: 'image_number', 'person/object'.
Story: {story}
"""
def get_entity_gpt4V(
    row: "pd.Series",
    *,
    model: str = "gpt-4-vision-preview",
) -> Optional[str]:
    """Ask a vision chat model for entities shown in a row's images but absent from its story.

    Builds one user message holding the formatted module-level ``prompt``
    followed by the three images of ``row``, sends it through the
    module-level ``client``, prints a human-readable trace of the exchange
    and returns the model's reply.

    Args:
        row: Record supporting ``row[key]`` lookup for ``'Input.story'``,
            ``'Input.image1'`` .. ``'Input.image3'`` and ``'HITId'``. The
            image values are passed to the API as image URLs.
        model: Chat-completions model name. Defaults to the model the
            script originally hard-coded.
            NOTE(review): OpenAI's deprecations page lists
            ``gpt-4-vision-preview`` -- confirm it is still served before
            relying on the default.

    Returns:
        The content of the first choice's message, exactly as the API
        returned it (it is not parsed into a Python list here).

    Raises:
        KeyError: If ``row`` lacks one of the keys listed above.
    """
    story = row['Input.story']
    now_prompt = prompt.format(story=story)

    # Exactly three images per row, matching the "3 images" the prompt
    # promises to the model.
    images = [row[f'Input.image{i}'] for i in range(1, 4)]

    # Text part first, then one image_url part per image, in order, so the
    # image numbers (1, 2, 3) the model reports line up with the inputs.
    content = [{"type": "text", "text": now_prompt}]
    content.extend(
        {"type": "image_url", "image_url": {"url": image_url}}
        for image_url in images
    )

    response = client.chat.completions.create(
        model=model,
        seed=RANDOM_SEED,
        messages=[
            {
                "role": "user",
                "content": content,
            }
        ],
        temperature=1,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    print(row['HITId'])
    print(now_prompt)
    pprint(images)
    out = response.choices[0].message.content
    print("OUTPUT:", out)
    print("====================================")
    print()

    # Previously the reply was only printed; returning it as well lets
    # callers collect results without scraping stdout.
    return out
def main(csv_path: str = './results.csv', limit: int = 10) -> None:
    """Run get_entity_gpt4V on the first ``limit`` distinct items of a results CSV.

    Rows are visited in file order. A row is skipped when its
    ``'Input.item_id'`` has already been seen, so each item is sent to the
    model at most once.

    Args:
        csv_path: Path of the CSV file to read with pandas.
        limit: Maximum number of distinct items to process.
    """
    df = pd.read_csv(csv_path)
    done = set()
    count = 0
    for _, row in df.iterrows():
        # Check the budget before doing any work so that limit=0 processes
        # nothing; with the default of 10 this stops after the tenth item.
        if count >= limit:
            break
        item_id = row['Input.item_id']
        if item_id in done:
            continue
        done.add(item_id)
        get_entity_gpt4V(row)
        count += 1


if __name__ == '__main__':
    main()