xu3kev's picture
init
9bf1f45
raw
history blame
15.2 kB
import argparse
import json
import numpy as np
import gradio as gr
import requests
from openai import OpenAI
from func_timeout import FunctionTimedOut, func_timeout
from tqdm import tqdm
# When True, run() skips the model server and uses MOCK_RESPONSE instead.
MOCK = True
# Name of a previously generated run folder under gradio_test_images/;
# rendered by test_gen_img_wrapper().
TEST_FOLDER = "c4f5"
# Instruction text sent to the model; {image_str} receives the digit grid built
# from the sketch. NOTE(review): wording is kept verbatim because it presumably
# matches the fine-tuning prompt format - confirm before editing the text.
INPUT_STRUCTION_TEMPLATE = """Here is a gray scale images representing with integer values 0-9.
{image_str}
Please write a Python program that generates the image using our own custom turtle module"""
# Instruction/response wrapper; {input_struction} receives the filled template above.
PROMPT_TEMPLATE = "### Instruction:\n{input_struction}\n### Response:\n"
# Sample 32x32 digit grid in the same format int_img_to_str() produces;
# not referenced by the code visible in this file section.
TEST_IMAGE_STR ="00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000001222222000000000000\n00000000000002000002000000000000\n00000000000002022202000000000000\n00000000000002020202000000000000\n00000000000002020002000000000000\n00000000000002022223000000000000\n00000000000002000000000000000000\n00000000000002000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000"
# Canned "model output" used by run() when MOCK is True: 16 copies of one
# program that, for i in 0..6, draws a 4-sided path of side 2*i (turning 90
# degrees each time) inside fork_state(). The program text is exec'd after
# LOGO_HEADER, so its indentation must be valid Python.
MOCK_RESPONSE = [
    """for i in range(7):
    with fork_state():
        for j in range(4):
            forward(2*i)
            left(90.0)
"""
] * 16
# Prelude that run_code() prepends to every generated program before exec'ing
# it: creates a myturtle.Turtle and exposes module-level drawing helpers
# (forward/left/right/teleport/pen control/state queries/fork_state) that the
# model's code calls. Function bodies must be indented for the result to be
# valid Python.
LOGO_HEADER = """from myturtle import Turtle
from myturtle import HALF_INF, INF, EPS_DIST, EPS_ANGLE
turtle = Turtle()
def forward(dist):
    turtle.forward(dist)
def left(angle):
    turtle.left(angle)
def right(angle):
    turtle.right(angle)
def teleport(x, y, theta):
    turtle.teleport(x, y, theta)
def penup():
    turtle.penup()
def pendown():
    turtle.pendown()
def position():
    return turtle.x, turtle.y
def heading():
    return turtle.heading
def isdown():
    return turtle.is_down
def fork_state():
    \"\"\"
    Fork the current state of the turtle.
    Usage:
    with fork_state():
        forward(100)
        left(90)
        forward(100)
    \"\"\"
    return turtle._TurtleState(turtle)"""
def invert_colors(image):
    """Return the first Sketchpad layer with every value flipped as ``255 - v``.

    Args:
        image (dict): Sketchpad value; only ``image['layers'][0]`` is read.

    Returns:
        numpy.ndarray: ``255 - layer``. Every channel is flipped, including
        any alpha channel the layer may carry.
    """
    first_layer = np.array(image['layers'][0])
    return 255 - first_layer
def crop_image_to_center(image, target_height=512, target_width=512,
                         detect_cropping_non_white=False, black_threshold=50):
    """Center-crop a 2-D grayscale image to ``target_height`` x ``target_width``.

    If the image is smaller than the target along an axis, the crop is padded
    with white (255) on the bottom/right so the result always has the target
    shape.

    Args:
        image: 2-D NumPy array (``h, w = image.shape``).
        target_height: number of rows in the returned image.
        target_width: number of columns in the returned image.
        detect_cropping_non_white: when True, return ``None`` instead of an
            image if the crop window would drop any "black" pixel.
        black_threshold: pixels strictly below this value count as black
            (default 50, the previously hard-coded value).

    Returns:
        The cropped (and possibly padded) array, or ``None`` when
        ``detect_cropping_non_white`` is set and black pixels would be lost.
    """
    h, w = image.shape
    center_y, center_x = h // 2, w // 2
    # Top-left corner of the crop window, clamped to the image.
    start_x = max(center_x - target_width // 2, 0)
    start_y = max(center_y - target_height // 2, 0)
    end_x = min(start_x + target_width, w)
    end_y = min(start_y + target_height, h)
    cropped_image = image[start_y:end_y, start_x:end_x]

    if detect_cropping_non_white:
        # Fewer black pixels inside the window than in the whole image means
        # the crop would discard part of the drawing.
        total_black = np.count_nonzero(image < black_threshold)
        kept_black = np.count_nonzero(cropped_image < black_threshold)
        if kept_black < total_black:
            return None

    # The window never exceeds the target, so these are >= 0.
    pad_height = target_height - cropped_image.shape[0]
    pad_width = target_width - cropped_image.shape[1]
    if pad_height > 0 or pad_width > 0:
        cropped_image = np.pad(
            cropped_image,
            ((0, pad_height), (0, pad_width)),
            mode="constant",
            constant_values=255,  # white padding
        )
    return cropped_image
def downscale_image(image, block_size=8, black_threshold=50, gray_level=10, return_level=False):
    """Downscale a grayscale image by quantized black-pixel coverage per block.

    The image is tiled into non-overlapping ``block_size`` x ``block_size``
    blocks. For each block, the fraction of pixels below ``black_threshold``
    is rounded to the nearest of ``gray_level`` levels (0 .. gray_level - 1,
    step ``1 / gray_level``). Coverage that would round past the top level is
    clamped to ``gray_level - 1``. Trailing rows/columns that do not fill a
    whole block are ignored, so outputs have shape
    ``(h // block_size, w // block_size)``.

    Args:
        image: 2-D array (``h, w = image.shape``).
        block_size: side length of each square block.
        black_threshold: pixels strictly below this count as black.
        gray_level: number of quantization levels (must be >= 1).
        return_level: also return the per-block integer level grid.

    Returns:
        ``downscaled`` (uint8, ``int(level * (1 / gray_level) * 255)``), or
        ``(downscaled, image_with_level)`` when ``return_level`` is True, where
        ``image_with_level`` (uint8) holds the integer level of each block.
    """
    h, w = image.shape
    new_h, new_w = h // block_size, w // block_size
    downscaled = np.zeros((new_h, new_w), dtype=np.uint8)
    image_with_level = np.zeros((new_h, new_w), dtype=np.uint8)

    discrete_gray_step = 1 / gray_level
    total_pixels = block_size * block_size
    max_level = gray_level - 1

    # Iterate over complete blocks only; partial edge blocks would index past
    # the (h // block_size, w // block_size) outputs.
    for bi in range(new_h):
        for bj in range(new_w):
            block = image[bi * block_size:(bi + 1) * block_size,
                          bj * block_size:(bj + 1) * block_size]
            black_pixels = int(np.count_nonzero(block < black_threshold))
            proportion_of_black = black_pixels / total_pixels
            # Nearest level, clamped so (near-)fully black blocks map to the top
            # level for any gray_level. For gray_level=10 this matches the old
            # "cap at 0.94" behaviour exactly.
            level = min(int(round(proportion_of_black / discrete_gray_step)), max_level)
            downscaled[bi, bj] = int(level * discrete_gray_step * 255)
            # Store the integer level directly. Recovering it with float floor
            # division misreports levels (0.5 // 0.1 == 4.0, 0.9 // 0.1 == 8.0).
            image_with_level[bi, bj] = level

    if return_level:
        return downscaled, image_with_level
    return downscaled
# Local port of the completion server that llm_call() connects to
# (base_url http://localhost:{PORT}/v1).
PORT = 8008
# Checkpoint requested from that server. The 33B variant used to be assigned
# here and then immediately overwritten; it is kept as a comment for switching.
# MODEL_NAME = "./axolotl/lora-logo_fix_full_deepseek33b_ds33i_epoch3_lr_0.0002_alpha_512_r_512_merged"
MODEL_NAME = "./axolotl/lora-logo_fix_full_deepseek7b_ds33i_lr_0.0002_alpha_512_r_512_merged"
def generate_grid_images(folder):
    """Tile up to 64 ``*.jpg`` images from ``folder`` into an 8x8 grid picture.

    Files are taken in sorted name order and loaded with ``load_img``. A file
    that cannot be loaded (``load_img`` returns ``None`` or raises) leaves its
    cell blank instead of aborting the whole grid.

    Args:
        folder: directory scanned (non-recursively) for ``*.jpg`` files.

    Returns:
        numpy array with the RGBA pixels of the rendered grid figure.
    """
    import glob
    import os

    import matplotlib.pyplot as plt

    num_rows, num_cols = 8, 8
    max_images = num_rows * num_cols
    # glob order is arbitrary; sort for a deterministic layout and only load
    # the files that can actually be shown.
    image_files = sorted(glob.glob(os.path.join(folder, "*.jpg")))[:max_images]

    images = []
    for image_file in image_files:
        try:
            images.append(load_img(image_file))
        except Exception as exc:  # best-effort: one bad file must not kill the grid
            print(f"Failed to load {image_file}: {exc}")
            images.append(None)
    print(f"Loaded {len(images)} images")

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 12))
    fig.tight_layout(pad=0)
    try:
        for idx, ax in enumerate(axes.flat):
            img = images[idx] if idx < len(images) else None
            if img is not None:
                ax.imshow(img, cmap='gray')
            ax.axis('off')  # no ticks/frames; empty cells stay blank
        fig.canvas.draw()
        image_array = np.array(fig.canvas.buffer_rgba())
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    return image_array
def llm_call(question_prompt, model_name,
             temperature=1, max_tokens=320,
             top_p=1, n_samples=64, stop=None):
    """Request ``n_samples`` completions of ``question_prompt`` from the local server.

    Connects an ``OpenAI`` client to ``http://localhost:{PORT}/v1`` (API key
    "empty") and calls the legacy completions endpoint with no frequency or
    presence penalty.

    Returns:
        The raw completions response object; each sample is in ``.choices``.
    """
    endpoint = f"http://localhost:{PORT}/v1"
    request = dict(
        model=model_name,
        prompt=question_prompt,
        n=n_samples,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=0,
        presence_penalty=0,
        stop=stop,
    )
    client = OpenAI(base_url=endpoint, api_key="empty")
    return client.completions.create(**request)
import cv2
def load_img(path, size=256):
    """Load an image as grayscale, crop it to its non-white content, and center it.

    Pixels at or below 240 count as content. Their bounding box is cut out,
    shrunk (aspect ratio preserved) if either side exceeds ``size``, and
    pasted in the middle of a white ``size`` x ``size`` canvas.

    Args:
        path: image file path passed to ``cv2.imread``.
        size: side length of the returned square canvas (default 256).

    Returns:
        ``uint8`` array of shape ``(size, size)``, or ``None`` when
        ``cv2.imread`` cannot read the file.
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # Unreadable/missing file; callers treat None as "leave this cell blank".
        return None
    # Binary mask: white (255) where the pixel is brighter than 240, else 0.
    _, thresh = cv2.threshold(img, 240, 255, cv2.THRESH_BINARY)
    # Invert so content is non-zero, which is what boundingRect measures.
    thresh_inv = cv2.bitwise_not(thresh)
    x, y, w, h = cv2.boundingRect(thresh_inv)
    roi = img[y:y + h, x:x + w]
    # Shrink the content to fit inside the canvas if it is too large.
    if w > size or h > size:
        scale = min(size / w, size / h)
        # max(1, ...) stops a very thin stroke from collapsing to a 0-pixel
        # side, which cv2.resize rejects.
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        roi = cv2.resize(roi, (new_w, new_h), interpolation=cv2.INTER_AREA)
        w, h = new_w, new_h
    # White size x size canvas with the content centered on it.
    centered_img = np.full((size, size), 255, dtype=np.uint8)
    start_x = max(0, (size - w) // 2)
    start_y = max(0, (size - h) // 2)
    centered_img[start_y:start_y + h, start_x:start_x + w] = roi
    return centered_img
def run_code(new_folder, counter, code):
    """Execute one generated turtle program and save its drawing as a JPEG.

    ``code`` is wrapped with LOGO_HEADER and a trailing ``turtle.save(...)``
    call that writes ``{new_folder}/logo_{counter}_.jpg``. Execution is limited
    to 3 seconds via ``func_timeout``. Timeouts and exceptions raised by the
    program are printed and swallowed, so one bad sample does not abort the
    batch.

    SECURITY NOTE(review): ``code`` is model output executed with ``exec`` in
    this process; it is NOT sandboxed. Only use with a trusted model/server.

    Args:
        new_folder: existing output directory.
        counter: sample index used in the output file name.
        code: Python source produced by the model.
    """
    import matplotlib.pyplot as plt

    fname = f"{new_folder}/logo_{counter}_.jpg"
    # {fname!r} emits a properly quoted/escaped Python string literal, so an
    # unusual character in the path cannot break the generated source.
    code_with_header_and_save = "\n".join(
        ["", LOGO_HEADER, code, f"turtle.save({fname!r})", ""]
    )
    try:
        func_timeout(3, exec, args=(code_with_header_and_save, {}))
    except FunctionTimedOut:
        print("Timeout")
    except Exception as e:
        print(e)
    finally:
        # Close figures even on timeout/failure. run() executes this in
        # ProcessPoolExecutor workers that are reused across samples, where
        # leaked figures would accumulate.
        plt.close("all")
def _make_unique_run_folder(parent, max_attempts=1000):
    """Create and return a brand-new ``parent/<4 hex chars>`` directory.

    ``os.makedirs`` is called without ``exist_ok``, so an id already used by an
    earlier run is rejected. Images from different runs can therefore never be
    mixed in one folder.
    """
    import os
    import secrets

    os.makedirs(parent, exist_ok=True)
    for _ in range(max_attempts):
        folder = os.path.join(parent, secrets.token_hex(2))
        try:
            os.makedirs(folder)
        except FileExistsError:
            continue  # id collision with an earlier run; draw another
        return folder
    raise RuntimeError(f"could not allocate a fresh run folder under {parent!r}")


def run(img_str):
    """Sample turtle programs for a digit-grid image and render each of them.

    Fills PROMPT_TEMPLATE/INPUT_STRUCTION_TEMPLATE with ``img_str``. Candidate
    programs come from ``llm_call`` (or MOCK_RESPONSE when MOCK is True). Every
    program is then executed in parallel worker processes by ``run_code``,
    which writes ``logo_{i}_.jpg`` into a new ``gradio_test_images/<id>``
    folder.

    Args:
        img_str: newline-separated rows of digits (see int_img_to_str).

    Returns:
        ``(folder_path, codes)``: the folder holding the rendered images and
        the list of program strings that were executed.
    """
    import os
    from concurrent.futures import ProcessPoolExecutor, as_completed

    prompt = PROMPT_TEMPLATE.format(input_struction=INPUT_STRUCTION_TEMPLATE.format(image_str=img_str))
    if not MOCK:
        response = llm_call(prompt, MODEL_NAME)
        print(response)
        codes = []
        for i, choice in enumerate(response.choices):
            print(f"Choice {i}: {choice.text}")
            codes.append(choice.text)
    else:
        codes = MOCK_RESPONSE

    new_folder = _make_unique_run_folder("gradio_test_images")

    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(run_code, new_folder, i, code) for i, code in enumerate(codes)]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                # run_code already swallows program errors; this catches
                # worker-level failures (e.g. a crashed process).
                print(f'Generated an exception: {exc}')

    print(os.path.basename(new_folder))
    return new_folder, codes
def test_gen_img_wrapper(_):
    """Callback adapter: ignore its input and render the grid for TEST_FOLDER."""
    test_folder = "gradio_test_images/" + TEST_FOLDER
    return generate_grid_images(test_folder)
def int_img_to_str(integer_img):
    """Render a 2-D grid of small integers as newline-separated digit rows.

    Each row is also printed as it is built.

    Args:
        integer_img: iterable of rows; each value is converted with ``str``.

    Returns:
        str: the rows joined with ``"\\n"`` (no trailing newline).
    """
    rows = []
    for pixel_row in integer_img:
        text = "".join(map(str, pixel_row))
        print(text)
        rows.append(text)
    return "\n".join(rows)
def img_to_code_img(sketchpad_img):
    """Gradio submit handler: turn a sketch into a grid of model-drawn images.

    Steps:
    - Take the first sketch layer.
    - Use its 4th channel as ink via ``255 - channel``.
    - Quantize it to a digit grid with ``downscale_image(block_size=16)``.
    - Ask the model for programs via ``run``.
    - Tile the rendered results with ``generate_grid_images``.

    Args:
        sketchpad_img: value of the ``gr.Sketchpad`` component. Assumed to be a
            dict whose ``'layers'`` list holds RGBA arrays; TODO confirm against
            the installed Gradio version.

    Returns:
        RGBA numpy array of the result grid.

    Raises:
        gr.Error: if the sketch has no layers (nothing was drawn).
    """
    layers = (sketchpad_img or {}).get('layers') or []
    if not layers:
        raise gr.Error("Please draw something on the sketchpad before submitting.")
    image_array = np.array(layers[0])
    # Presumably strokes are opaque (alpha 255) and the untouched canvas is
    # transparent (alpha 0), so this maps ink to dark and background to white.
    image_array = 255 - image_array[:, :, 3]
    _, int_img = downscale_image(image_array, block_size=16, return_level=True)
    img_str = int_img_to_str(int_img)
    print(img_str)
    folder_path, _codes = run(img_str)
    return generate_grid_images(folder_path)
def main():
    """Build the Gradio Blocks UI for the sketch-to-turtle-program demo and launch it.

    The page shows intro text, a 512x512 sketchpad with a fixed black brush, a
    Submit button wired to ``img_to_code_img``, and an output image. A page-load
    JS hook forces the light theme by adding ``__theme=light`` to the URL.
    The app is launched with ``share=True``.
    """
    import gradio as gr
    from gradio import Brush

    intro_markdown = (
        '# Visual Program Synthesis with LLM',
        """LOGO/Turtle graphics Programming-by-Example problems aims to synthesize a program that generates the given target image, where the program uses drawing library similar to Python Turtle.""",
        """Here we can draw a target image using the sketchpad, and see what kinds of graphics program LLM generates. To allow the LLM to visually perceive the input image, we convert the image to ASCII strings.""",
        "## Draw logo",
    )
    force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}"""

    theme = gr.themes.Default().set()
    with gr.Blocks(theme=theme) as demo:
        for markdown_text in intro_markdown:
            gr.Markdown(markdown_text)
        with gr.Column():
            sketch_brush = Brush(colors=["black"], default_size=3, color_mode='fixed')
            canvas = gr.Sketchpad(canvas_size=(512, 512), brush=sketch_brush)
            submit_button = gr.Button("Submit")
            output_image = gr.Image(label="output")
        submit_button.click(img_to_code_img, inputs=canvas, outputs=output_image)
        demo.load(None, None, js=force_light_theme_js)
    demo.launch(share=True)
if __name__ == "__main__":
    # Launch the Gradio demo (the model server is expected at localhost:PORT
    # unless MOCK is True).
    main()