import base64
import csv
from io import StringIO

import gradio as gr
import numpy as np
from PIL import Image

# Shape kinds selectable from the "Annotation Type" dropdown.
ANNOTATION_TYPES = ['rect', 'circle']


class Annotation:
    """One user-drawn shape, stored as an axis-aligned bounding box.

    Args:
        x: Minimum x of the bounding box, in canvas coordinates.
        y: Minimum y of the bounding box, in canvas coordinates.
        width: Box width (non-negative).
        height: Box height (non-negative).
        annotation_type: One of ``ANNOTATION_TYPES``; stored as ``self.type``.
    """

    def __init__(self, x, y, width, height, annotation_type):
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.type = annotation_type

    def __repr__(self):
        return (f'{type(self).__name__}(x={self.x!r}, y={self.y!r}, '
                f'width={self.width!r}, height={self.height!r}, '
                f'annotation_type={self.type!r})')


def annotate_images(images):
    """Build a Gradio interface for drawing rect/circle annotations on an image.

    The user uploads an image, drags on the canvas to draw shapes of the type
    chosen in the dropdown, and can export all shapes as a CSV data-URI link.

    Args:
        images: Not used by this builder; kept so existing callers still work.

    Returns:
        The constructed ``gr.Interface``.

    NOTE(review): several component APIs used here (``gr.inputs.Dropdown``
    with ``onchange``, ``gr.outputs.Button``, ``gr.outputs.Canvas``,
    ``gr.Interface.show``, ``capture_session``) do not match Gradio's
    documented components — confirm against the installed Gradio version.
    """
    canvas_size = (600, 600)

    # Mutable UI state shared by every handler below via closure.
    state = {
        'image': None,
        'annotations': [],
        'annotation_type': ANNOTATION_TYPES[0],
        'start_point': None,
    }

    def _bounding_box(start_x, start_y, end_x, end_y):
        """Return ``(x, y, width, height)`` of the box spanned by two drag points.

        The origin is the minimum corner, so boxes dragged up or left are
        placed where the user drew them instead of being mirrored.
        """
        return (min(start_x, end_x), min(start_y, end_y),
                abs(start_x - end_x), abs(start_y - end_y))

    def _draw_shape(target, x, y, width, height, annotation_type):
        """Draw one red shape outline from its bounding box onto ``target``."""
        if annotation_type == 'rect':
            target.draw_rect(x, y, width, height, stroke_color='red')
        elif annotation_type == 'circle':
            # Circle circumscribing the box: radius is half the box diagonal.
            radius = np.hypot(width, height) / 2
            target.draw_circle(x + width / 2, y + height / 2, radius,
                               stroke_color='red')

    def draw_canvas(canvas, image_data, annotations):
        """Draw the image (if one is loaded) resized to the canvas, then all shapes."""
        # Mouse events can arrive before any image is uploaded;
        # Image.fromarray(None) would raise.
        if image_data is not None:
            image = Image.fromarray(image_data).resize(canvas_size)
            canvas.draw_image(image, (canvas_size[0] / 2, canvas_size[1] / 2))
        for annotation in annotations:
            _draw_shape(canvas, annotation.x, annotation.y, annotation.width,
                        annotation.height, annotation.type)

    def _redraw(target):
        """Clear ``target`` and rebuild it from the current state."""
        target.clear()
        draw_canvas(target, state['image'], state['annotations'])

    def canvas_mousedown(canvas, x, y):
        """Remember where the drag started."""
        state['start_point'] = (x, y)

    def canvas_mousemove(canvas, x, y):
        """While dragging, show a live preview of the shape being drawn."""
        if state['start_point'] is None:
            return
        start_x, start_y = state['start_point']
        draw_annotation(canvas, start_x, start_y, x, y, state['annotation_type'])

    def canvas_mouseup(canvas, x, y):
        """Finish the drag: commit the shape and reset the drag state."""
        if state['start_point'] is None:
            return
        start_x, start_y = state['start_point']
        add_annotation(start_x, start_y, x, y, state['annotation_type'], canvas)
        state['start_point'] = None

    def add_annotation(start_x, start_y, end_x, end_y, annotation_type,
                       target=None):
        """Store the dragged shape and redraw.

        ``target`` is the canvas to redraw; it defaults to the interface's
        canvas component.
        """
        x, y, width, height = _bounding_box(start_x, start_y, end_x, end_y)
        state['annotations'].append(
            Annotation(x, y, width, height, annotation_type))
        # Clear first (as draw_annotation does) so the canvas is rebuilt from
        # state rather than painted over the last drag preview.
        _redraw(canvas if target is None else target)

    def draw_annotation(canvas, start_x, start_y, end_x, end_y, annotation_type):
        """Redraw committed shapes plus a preview of the in-progress drag."""
        _redraw(canvas)
        x, y, width, height = _bounding_box(start_x, start_y, end_x, end_y)
        _draw_shape(canvas, x, y, width, height, annotation_type)

    def annotation_type_changed(value):
        """Dropdown handler: set the shape type used for the next drag."""
        state['annotation_type'] = value

    def download_annotations_clicked():
        """Serialize all annotations to CSV and show them as a data-URI link."""
        buffer = StringIO()
        writer = csv.writer(buffer)
        writer.writerow(['x', 'y', 'width', 'height', 'type'])
        writer.writerows(
            [str(a.x), str(a.y), str(a.width), str(a.height), a.type]
            for a in state['annotations'])
        b64_csv = base64.b64encode(buffer.getvalue().encode()).decode()
        href = f'data:text/csv;base64,{b64_csv}'
        # The href must be embedded in an anchor; bare text gives the user
        # nothing to click or download.
        download_link = (f'<a href="{href}" download="annotations.csv">'
                         'Download Annotations CSV</a>')
        gr.Interface.show(download_link)

    # Interface components.
    image = gr.inputs.Image(label='Image')
    annotation_type = gr.inputs.Dropdown(
        ANNOTATION_TYPES, label='Annotation Type',
        default=ANNOTATION_TYPES[0], onchange=annotation_type_changed)
    download_annotations = gr.outputs.Button(
        label='Download Annotations', type='button',
        onclick=download_annotations_clicked)
    canvas = gr.outputs.Canvas(draw_event_handlers={
        'mousedown': canvas_mousedown,
        'mousemove': canvas_mousemove,
        'mouseup': canvas_mouseup,
    })

    def process_image(images):
        """Interface function: load the uploaded image and draw it.

        NOTE(review): the Image input presumably delivers a single image
        (a numpy array by default); indexing ``[0]`` into an array would keep
        only its first pixel row. A list/tuple of images is still accepted,
        in which case the first one is used.
        """
        if isinstance(images, (list, tuple)):
            images = images[0] if images else None
        state['image'] = images
        # The exported CSV has no image column, so shapes from a previous
        # image must not carry over to the new one.
        state['annotations'] = []
        state['start_point'] = None
        draw_canvas(canvas, state['image'], state['annotations'])
        return canvas, annotation_type, download_annotations

    interface = gr.Interface(
        process_image,
        inputs=image,
        outputs=[canvas, annotation_type, download_annotations],
        capture_session=True)
    return interface