# Hugging Face Spaces page status at capture time: Runtime error
import base64
import csv
from io import StringIO

import gradio as gr
import numpy as np
from PIL import Image
# Shape kinds the annotation tool can draw. The first entry is used as the
# default selection wherever ANNOTATION_TYPES[0] is read.
ANNOTATION_TYPES = ['rect', 'circle']
# Define the Annotation class
class Annotation:
    """One drawn shape, described by its bounding box and shape kind.

    Attributes are plain public fields:
        x, y: origin of the bounding box.
        width, height: extent of the bounding box.
        type: shape kind as passed in ``annotation_type``
            (the file draws 'rect' and 'circle').
    """

    def __init__(self, x: float, y: float, width: float, height: float,
                 annotation_type: str) -> None:
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        # Stored under ``type`` (not ``annotation_type``); readers of this
        # object elsewhere in the file rely on that attribute name.
        self.type = annotation_type

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(x={self.x!r}, y={self.y!r}, "
            f"width={self.width!r}, height={self.height!r}, "
            f"annotation_type={self.type!r})"
        )
def _drag_bounds(start_x, start_y, end_x, end_y):
    """Return ``(x, y, width, height)`` of the box spanned by a mouse drag.

    The origin is the top-left corner whatever the drag direction, so width
    and height are always non-negative and the box always covers the drag.
    """
    return (
        min(start_x, end_x),
        min(start_y, end_y),
        abs(start_x - end_x),
        abs(start_y - end_y),
    )


def _draw_shape(canvas, x, y, width, height, shape_type):
    """Draw one red 'rect' or 'circle' outline for the given bounding box.

    Any other ``shape_type`` is silently ignored.
    """
    if shape_type == 'rect':
        canvas.draw_rect(x, y, width, height, stroke_color='red')
    elif shape_type == 'circle':
        # The circle is centred on the box and its radius is half the box
        # diagonal, i.e. it passes through the box corners.
        radius = np.hypot(width, height) / 2
        canvas.draw_circle(x + width / 2, y + height / 2, radius, stroke_color='red')


# Define the Gradio interface
def annotate_images(images=None):
    """Build and return the Gradio interface for annotating an image.

    One state dict (current image, committed annotations, selected shape
    kind, drag start point) is created per call and shared by all nested
    event handlers of the returned interface.

    Args:
        images: Unused by the builder; kept so existing callers that pass an
            argument continue to work.

    Returns:
        The ``gr.Interface`` wired to the image input, canvas, shape-kind
        dropdown and download button.
    """
    # Define the canvas size
    canvas_size = (600, 600)

    # Define the initial state
    state = {
        'image': None,
        'annotations': [],
        'annotation_type': ANNOTATION_TYPES[0],
        'start_point': None,
    }

    def draw_canvas(canvas, image_data, annotations):
        """Paint the resized image (if any) and every committed annotation."""
        # Mouse events can arrive before an image has been submitted; skip
        # the image layer instead of failing in Image.fromarray(None).
        if image_data is not None:
            image = Image.fromarray(image_data).resize(canvas_size)
            canvas.draw_image(image, (canvas_size[0] / 2, canvas_size[1] / 2))
        for annotation in annotations:
            _draw_shape(
                canvas,
                annotation.x,
                annotation.y,
                annotation.width,
                annotation.height,
                annotation.type,
            )

    def canvas_mousedown(canvas, x, y):
        """Remember where the drag started."""
        state['start_point'] = (x, y)

    def canvas_mousemove(canvas, x, y):
        """Show a live preview of the shape while a drag is in progress."""
        if state['start_point'] is None:
            return
        start_x, start_y = state['start_point']
        draw_annotation(canvas, start_x, start_y, x, y, state['annotation_type'])

    def canvas_mouseup(canvas, x, y):
        """Commit the dragged shape and end the drag."""
        if state['start_point'] is None:
            return
        start_x, start_y = state['start_point']
        add_annotation(start_x, start_y, x, y, state['annotation_type'])
        state['start_point'] = None

    def add_annotation(start_x, start_y, end_x, end_y, annotation_type):
        """Store the shape spanned by the drag and redraw the canvas."""
        x, y, width, height = _drag_bounds(start_x, start_y, end_x, end_y)
        # A click without any movement has no extent; do not record it.
        if width != 0 or height != 0:
            state['annotations'].append(Annotation(x, y, width, height, annotation_type))
        # ``canvas`` is the component created further down in this function;
        # it is bound by the time any event handler can run.
        draw_canvas(canvas, state['image'], state['annotations'])

    def draw_annotation(canvas, start_x, start_y, end_x, end_y, annotation_type):
        """Redraw everything, then overlay the not-yet-committed shape."""
        canvas.clear()
        draw_canvas(canvas, state['image'], state['annotations'])
        x, y, width, height = _drag_bounds(start_x, start_y, end_x, end_y)
        _draw_shape(canvas, x, y, width, height, annotation_type)

    def annotation_type_changed(value):
        """Record the shape kind picked in the dropdown."""
        state['annotation_type'] = value

    def download_annotations_clicked():
        """Serialise all annotations to CSV and expose them as a data-URI link."""
        buffer = StringIO()
        writer = csv.writer(buffer)
        writer.writerow(['x', 'y', 'width', 'height', 'type'])
        writer.writerows(
            [str(a.x), str(a.y), str(a.width), str(a.height), a.type]
            for a in state['annotations']
        )
        b64_csv = base64.b64encode(buffer.getvalue().encode()).decode()
        href = f'data:text/csv;base64,{b64_csv}'
        download_link = f'<a href="{href}" download="annotations.csv">Download Annotations CSV</a>'
        # NOTE(review): could not confirm that ``gr.Interface.show`` exists;
        # verify against the installed Gradio version.
        gr.Interface.show(download_link)

    # Define the interface components
    # NOTE(review): the ``gr.inputs`` / ``gr.outputs`` namespaces, the
    # ``Canvas`` and ``Button`` output components and the ``onchange`` /
    # ``onclick`` / ``draw_event_handlers`` keyword arguments are kept exactly
    # as written but could not be confirmed against the Gradio API -- TODO
    # verify; they are a likely source of the Space's runtime error.
    image = gr.inputs.Image(label='Image')
    annotation_type = gr.inputs.Dropdown(
        ANNOTATION_TYPES,
        label='Annotation Type',
        default=ANNOTATION_TYPES[0],
        onchange=annotation_type_changed,
    )
    download_annotations = gr.outputs.Button(
        label='Download Annotations',
        type='button',
        onclick=download_annotations_clicked,
    )
    canvas = gr.outputs.Canvas(draw_event_handlers={
        'mousedown': canvas_mousedown,
        'mousemove': canvas_mousemove,
        'mouseup': canvas_mouseup,
    })

    def on_image_submitted(images):
        """Interface callback: load the submitted image and repaint the canvas."""
        # The interface has a single image input, so the callback presumably
        # receives the image itself; indexing it would take its first pixel
        # row. A list/tuple of images is still accepted and its first entry
        # used, as before.
        if isinstance(images, (list, tuple)):
            state['image'] = images[0]
        else:
            state['image'] = images
        draw_canvas(canvas, state['image'], state['annotations'])
        return canvas, annotation_type, download_annotations

    # Create the interface
    # NOTE(review): ``capture_session`` is kept from the original call;
    # confirm the installed Gradio version still accepts it.
    interface = gr.Interface(
        on_image_submitted,
        inputs=image,
        outputs=[canvas, annotation_type, download_annotations],
        capture_session=True,
    )
    return interface