# Captured from a Hugging Face Space page (Space status at capture time: Paused).
import base64 | |
import io | |
import math | |
import os | |
import random | |
import json | |
import re | |
from typing import List, Tuple | |
import PIL | |
import gradio as gr | |
import outlines | |
import requests | |
from outlines import models, generate, samplers | |
from pydantic import BaseModel | |
# Configuration constants
# (width, height) bound handed to Image.thumbnail() before encoding.
MAX_IMAGE_SIZE = (1024, 1024)
# Byte budget passed to compress_image_to_jpeg() for the encoded board image.
TARGET_IMAGE_SIZE = 180_000
# Endpoint that receives the board image for word detection.
NVIDIA_API_URL = "https://ai.api.nvidia.com/v1/vlm/microsoft/phi-3-vision-128k-instruct"
# Identifier of the text model (presumably a Hugging Face Hub id - confirm).
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
class Clue(BaseModel):
    """Schema for one generated clue; used to constrain the model's JSON output."""

    # The clue word itself.
    word: str
    # Text explaining the clue.
    explanation: str
class Group(BaseModel):
    """Schema for one cluster of board words together with its clue."""

    # Words that belong to this cluster.
    words: List[str]
    # Clue associated with the cluster.
    clue: str
    # Text explaining the grouping; shown as the dropdown's info text in the UI.
    explanation: str
class Groups(BaseModel):
    """Top-level schema for a grouping answer: a list of Group objects."""

    # All clusters produced for one request.
    groups: List[Group]
# Few-shot examples shared by the grouping and clue prompts.
# Each entry is (board words, clue, explanation of how the clue links the words).
example_clues = [
    (['ARROW', 'TIE', 'HONOR'], 'BOW', 'such as a bow and arrow, a bow tie, or a bow as a sign of honor'),
    (['DOG', 'TREE'], 'BARK', 'such as the sound a dog makes, or a tree is made of bark'),
    (['MONEY', 'RIVER', 'ROB', 'BLOOD'], 'CRIME', 'such as money being stolen, a river being a potential crime scene, '
                                                   'robbery, or blood being a result of a violent crime'),
    (['BEEF', 'TURKEY', 'FIELD', 'GRASS'], 'GROUND',
     'such as ground beef, a turkey being a ground-dwelling bird, a field or grass being a type of ground'),
    (['BANK', 'GUITAR', 'LIBRARY'], 'NOTE',
     'such as a bank note, a musical note on a guitar, or a note being a written comment in a library book'),
    (['ROOM', 'PIANO', 'TYPEWRITER'], 'KEYS', 'such as a room key, piano keys, or typewriter keys'),
    (['TRAFFIC', 'RADAR', 'PHONE'], 'SIGNAL', 'such as traffic signals, radar signals, or phone signals'),
    (['FENCE', 'PICTURE', 'COOKIE'], 'FRAME',
     'such as a frame around a yard, a picture frame, or a cookie cutter being a type of frame'),
    (['YARN', 'VIOLIN', 'DRESS'], 'STRING', 'strings like material, instrument, clothing fastener'),
    # Fixed typo: the clock example previously read "sprint component".
    (['JUMP', 'FLOWER', 'CLOCK'], 'SPRING',
     'such as jumping, flowers blooming in the spring, or a clock having a spring component'),
    (['SPY', 'KNIFE'], 'WAR',
     'Both relate to aspects of war, such as spies being involved in war or knives being used as weapons'),
    (['STADIUM', 'SHOE', 'FIELD'], 'SPORT', 'Sports like venues, equipment, playing surfaces'),
    (['TEACHER', 'CLUB'], 'SCHOOL',
     'such as a teacher being a school staff member or a club being a type of school organization'),
    (['CYCLE', 'ARMY', 'COURT', 'FEES'], 'CHARGE', 'charges like electricity, battle, legal, payments'),
    (['FRUIT', 'MUSIC', 'TRAFFIC', 'STUCK'], 'JAM',
     'Jams such as fruit jam, a music jam session, traffic jam, or being stuck in a jam'),
    (['POLICE', 'DOG', 'THIEF'], 'CRIME',
     'such as police investigating crimes, dogs being used to detect crimes, or a thief committing a crime'),
    (['ARCTIC', 'SHUT', 'STAMP'], 'SEAL',
     'such as the Arctic being home to seals, or shutting a seal on an envelope, or a stamp being a type of seal'),
]
def create_random_word_groups(clues: List[Tuple[List[str], str, str]], target_groups: int = 10,
                              max_attempts_per_group: int = 100) -> List[Tuple[List[str], List[int]]]:
    """
    Creates up to 'target_groups' random groups of words from the given clues.

    Each group merges the word lists of 3 or 4 randomly chosen, distinct clues and is
    kept only when the merged list holds 8 or 9 words.

    Args:
        clues: A list of clues, where each clue is a tuple (words, answer, explanation).
        target_groups: The desired number of groups to create.
        max_attempts_per_group: Random draws allowed per requested group before giving up.
            Together with target_groups this bounds the total number of draws.

    Returns:
        A list of tuples, each containing a list of merged words and the indices of the
        clues they came from. The list is shorter than 'target_groups' (possibly empty)
        when the clues cannot produce enough 8- or 9-word combinations.
    """
    # random.sample() raises if asked for more rows than there are clues.
    row_counts = [count for count in (3, 4) if count <= len(clues)]
    if not row_counts:
        return []

    groups = []
    # A finite budget keeps the loop from spinning forever when no combination of
    # rows can ever total 8 or 9 words (e.g. every clue has four words).
    attempts_left = max(target_groups, 0) * max_attempts_per_group
    while len(groups) < target_groups and attempts_left > 0:
        attempts_left -= 1
        num_rows = random.choice(row_counts)
        selected_indices = random.sample(range(len(clues)), num_rows)
        merged_words = [word for i in selected_indices for word in clues[i][0]]
        if len(merged_words) in (8, 9):
            groups.append((merged_words, selected_indices))
    return groups
def group_words(word_list: List[str]) -> List[Group]:
    """
    Groups the given words into 3 to 4 thematic groups.

    Builds a few-shot chat prompt from randomly merged example clues and asks the
    module-level `model` for JSON constrained to the Groups schema.

    Args:
        word_list: A list of words to be grouped.

    Returns:
        A list of Group objects representing the grouped words. An empty list is
        returned without querying the model when no words are given.
    """
    if not word_list:
        return []

    # outlines.prompt turns the docstring below into a Jinja2 template; without the
    # decorator the function body is only a docstring and the call returns None.
    @outlines.prompt
    def chat_group_template(system_prompt, query, history=()):
        '''<s><|system|>
        {{ system_prompt }}
        {% for example in history %}
        <|user|>
        {{ example[0] }}<|end|>
        <|assistant|>
        {{ example[1] }}<|end|>
        {% endfor %}
        <|user|>
        {{ query }}<|end|>
        <|assistant|>
        '''

    grouping_system_prompt = ("You are an assistant for the game Codenames. Your task is to help players by grouping a "
                              "given set of words into 3 to 4 groups. Each group should consist of words that "
                              "share a common theme or other word connections such as homonyms, hypernyms, or synonyms.")

    example_groupings = []
    for merged_words, indices in create_random_word_groups(example_clues, 5):
        # Mirror the Groups/Group schema exactly: decoding is constrained to that
        # schema, so the few-shot answers must use the same keys and nesting.
        example_answer = {
            "groups": [{
                "words": example_clues[i][0],
                "clue": example_clues[i][1],
                "explanation": example_clues[i][2]
            } for i in indices]
        }
        example_groupings.append((merged_words, json.dumps(example_answer, separators=(',', ':'))))

    prompt = chat_group_template(grouping_system_prompt, word_list, example_groupings)
    # Greedy decoding: always take the most likely token.
    sampler = samplers.greedy()
    # `model` is read as a global; the script entry point binds it at start-up.
    generator = generate.json(model, Groups, sampler)
    print(f"Grouping words: {word_list}")
    generations = generator(prompt, max_tokens=500)
    print(f"Generated groupings: {generations}")
    return generations.groups
def generate_clue(group: List[str]) -> Clue:
    """
    Generates a single-word clue for the given group of words.

    Builds a few-shot chat prompt from `example_clues` and samples JSON constrained
    to the Clue schema from the module-level `model`.

    Args:
        group: A list of words to generate a clue for.

    Returns:
        A Clue object containing the generated word and its explanation (the first
        of the two sampled candidates).
    """
    # outlines.prompt turns the docstring below into a Jinja2 template; without the
    # decorator the function body is only a docstring and the call returns None.
    # The example answers use the Clue schema's own keys ("word", "explanation")
    # because decoding is constrained to that schema.
    @outlines.prompt
    def chat_clue_template(system, query, history=()):
        '''<s><|system|>
        {{ system }}
        {% for example in history %}
        <|user|>
        {{ example[0] }}<|end|>
        <|assistant|>
        {"word": "{{ example[1] }}", "explanation": "{{ example[2] }}"}<|end|>
        {% endfor %}
        <|user|>
        {{ query }}<|end|>
        <|assistant|>
        '''

    clue_system_prompt = ("You are a Codenames game companion. Your task is to give a single word clue related to "
                          "a given group of words. Respond with a single word clue only. Compound words are "
                          "allowed. Do not include the word 'Clue'. Do not provide explanations or notes.")

    prompt = chat_clue_template(clue_system_prompt, group, example_clues)
    # Two samples drawn from the 10 most likely tokens at each step.
    sampler = samplers.multinomial(2, top_k=10)
    # `model` is read as a global; the script entry point binds it at start-up.
    generator = generate.json(model, Clue, sampler)
    generations = generator(prompt, max_tokens=100)
    print(f"Generated clues: {generations}")
    return generations[0]
def compress_image_to_jpeg(image: 'PIL.Image', target_size: int) -> bytes:
    """
    Compresses the image to JPEG format with the best quality that fits within the target size.

    Binary-searches the JPEG quality setting between 25 and 96, which assumes the
    encoded size does not shrink as quality rises.
    https://stackoverflow.com/a/52281257

    Args:
        image: The PIL Image object to compress.
        target_size: The target file size in bytes.

    Returns:
        The compressed image as bytes.

    Raises:
        ValueError: If the image exceeds target_size even at the lowest quality tried.
    """
    # JPEG cannot store alpha or palette images; convert those so save() does not fail.
    if image.mode not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")

    lowest_quality, low, high = 25, 25, 96
    best_quality = -1
    best_encoding = b""
    while low <= high:
        quality = (low + high) // 2
        # Encode into memory and measure the result.
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        encoded = buffer.getvalue()
        if len(encoded) <= target_size:
            # Fits: remember these bytes (no second encode needed) and try higher.
            best_quality, best_encoding = quality, encoded
            low = quality + 1
        else:
            high = quality - 1

    if best_quality == -1:
        # Fail loudly instead of returning None from a function declared to return bytes.
        raise ValueError(
            f"Image cannot be compressed to {target_size} bytes even at JPEG quality {lowest_quality}")

    print("Acceptable quality", image, image.format, f"{image.size}x{image.mode}")
    return best_encoding
def process_image(img: 'PIL.Image') -> gr.update:
    """
    Processes the uploaded image to detect words for the Codenames game.

    The image is shrunk in place to MAX_IMAGE_SIZE, JPEG-compressed to
    TARGET_IMAGE_SIZE bytes, and sent inline (base64) to NVIDIA_API_URL.

    Args:
        img: The uploaded PIL Image object.

    Returns:
        A gradio update object with the detected words.

    Raises:
        gr.Error: If no image was supplied, the image cannot be compressed, or the
            request fails or returns a non-OK status.
    """
    if img is None:
        raise gr.Error("Please upload an image of the game board first.")

    img.thumbnail(MAX_IMAGE_SIZE)
    image_byte_array = compress_image_to_jpeg(img, TARGET_IMAGE_SIZE)
    if not image_byte_array:
        raise gr.Error("The image could not be compressed enough to be sent for word detection.")
    image_b64 = base64.b64encode(image_byte_array).decode()

    headers = {
        "Authorization": f"Bearer {os.environ.get('NVIDIA_API_KEY', '')}",
        "Accept": "application/json"
    }
    payload = {
        "messages": [
            {
                "role": "user",
                # The bytes are JPEG, so the data URI must declare image/jpeg.
                "content": f'Identify the words in this game of Codenames. Provide only a list of words in capital letters. <img src="data:image/jpeg;base64,{image_b64}" />'
            }
        ],
        "max_tokens": 512,
        "temperature": 0.1,
        "top_p": 0.70,
        "stream": False
    }

    try:
        # Without a timeout a stalled connection would block the UI callback forever.
        response = requests.post(NVIDIA_API_URL, headers=headers, json=payload, timeout=60)
    except requests.RequestException as err:
        raise gr.Error(f"Word detection request failed: {err}") from err
    if not response.ok:
        raise gr.Error(f"Word detection failed (HTTP {response.status_code}).")

    body = response.json()
    print(body)
    # One or two upper-case words separated by spaces/tabs only, so two-word entries
    # survive while words on separate lines are not glued together.
    pattern = r'[A-Z]+(?:[ \t]+[A-Z]+)?'
    words = re.findall(pattern, body['choices'][0]['message']['content'])
    return gr.update(choices=words, value=words)
def pad_or_truncate_groups(groups: List[Group], target_length: int = 4) -> List[Group]:
    """
    Returns a copy of 'groups' cut to at most target_length entries and filled up
    with empty Group objects when it is shorter.

    Args:
        groups: The list of Group objects to pad or truncate.
        target_length: The desired length of the list.

    Returns:
        A new list of Group objects with the specified length.
    """
    result = groups[:target_length]
    missing = target_length - len(result)
    # Each filler is a separate Group instance.
    result.extend(Group(words=[], clue='', explanation='') for _ in range(missing))
    return result
def group_words_callback(words: List[str]) -> List[gr.update]:
    """
    UI callback: groups the selected words and produces one update per group input.

    Args:
        words: A list of words to group.

    Returns:
        A list of four gradio update objects; each sets a group's words as both the
        choices and the selected value, with the explanation as info text.
    """
    padded = pad_or_truncate_groups(group_words(words), 4)
    print(f"Generated groups: {padded}")
    updates = []
    for entry in padded:
        updates.append(gr.update(value=entry.words, choices=entry.words, info=entry.explanation))
    return updates
if __name__ == '__main__':
    # Load the text model once at start-up. It is bound at module level on purpose:
    # group_words() and generate_clue() read `model` as a global.
    model = models.transformers(MODEL_NAME,
                                model_kwargs={'device_map': "cuda", 'torch_dtype': "auto",
                                              'trust_remote_code': True,
                                              'attn_implementation': "flash_attention_2"})

    def generate_clues_callback(group):
        """Generate a clue for one group's words and return the textbox update."""
        print("Generating clues: ", group)
        g = generate_clue(group)
        return gr.update(value=g.word, info=g.explanation)

    with gr.Blocks() as demo:
        gr.Markdown("# *Codenames* clue generator")
        gr.Markdown("Provide a list of words to generate a clue")

        with gr.Row():
            game_image = gr.Image(type="pil")
            word_list_input = gr.Dropdown(label="Enter list of words (comma separated)",
                                          choices=[],
                                          multiselect=True,
                                          interactive=True)

        with gr.Row():
            detect_words_button = gr.Button("Detect Words")
            group_words_button = gr.Button("Group Words")

        # One row per group: editable word list, trigger button, clue output.
        dropdowns, buttons, outputs = [], [], []
        for i in range(4):
            with gr.Row():
                group_input = gr.Dropdown(label=f"Group {i + 1}",
                                          choices=[],
                                          allow_custom_value=True,
                                          multiselect=True,
                                          interactive=True)
                clue_button = gr.Button("Generate Clue", size='sm')
                clue_output = gr.Textbox(label=f"Clue {i + 1}")
            dropdowns.append(group_input)
            buttons.append(clue_button)
            outputs.append(clue_output)

        detect_words_button.click(fn=process_image,
                                  inputs=game_image,
                                  outputs=[word_list_input])
        # Uses the module-level group_words_callback; the local re-definitions that
        # used to shadow it (and its padding helper) were duplicates and are gone.
        group_words_button.click(fn=group_words_callback,
                                 inputs=word_list_input,
                                 outputs=dropdowns)
        for clue_button, group_input, clue_output in zip(buttons, dropdowns, outputs):
            clue_button.click(generate_clues_callback, inputs=group_input, outputs=clue_output)

    demo.launch(share=False)