Spaces:
Running
Running
import os | |
from label_studio_converter import brush | |
from typing import List, Dict, Optional | |
from uuid import uuid4 | |
from sam_predictor import SAMPredictor | |
from label_studio_ml.model import LabelStudioMLBase | |
SAM_CHOICE = os.environ.get("SAM_CHOICE", "MobileSAM") # other option is just SAM | |
PREDICTOR = SAMPredictor(SAM_CHOICE) | |
class SamMLBackend(LabelStudioMLBase): | |
def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]: | |
""" Returns the predicted mask for a smart keypoint that has been placed.""" | |
from_name, to_name, value = self.get_first_tag_occurence('BrushLabels', 'Image') | |
if not context or not context.get('result'): | |
# if there is no context, no interaction has happened yet | |
return [] | |
image_width = context['result'][0]['original_width'] | |
image_height = context['result'][0]['original_height'] | |
# collect context information | |
point_coords = [] | |
point_labels = [] | |
input_box = None | |
selected_label = None | |
for ctx in context['result']: | |
x = ctx['value']['x'] * image_width / 100 | |
y = ctx['value']['y'] * image_height / 100 | |
ctx_type = ctx['type'] | |
selected_label = ctx['value'][ctx_type][0] | |
if ctx_type == 'keypointlabels': | |
point_labels.append(int(ctx['is_positive'])) | |
point_coords.append([int(x), int(y)]) | |
elif ctx_type == 'rectanglelabels': | |
box_width = ctx['value']['width'] * image_width / 100 | |
box_height = ctx['value']['height'] * image_height / 100 | |
input_box = [int(x), int(y), int(box_width + x), int(box_height + y)] | |
print(f'Point coords are {point_coords}, point labels are {point_labels}, input box is {input_box}') | |
img_path = tasks[0]['data'][value] | |
predictor_results = PREDICTOR.predict( | |
img_path=img_path, | |
point_coords=point_coords or None, | |
point_labels=point_labels or None, | |
input_box=input_box | |
) | |
predictions = self.get_results( | |
masks=predictor_results['masks'], | |
probs=predictor_results['probs'], | |
width=image_width, | |
height=image_height, | |
from_name=from_name, | |
to_name=to_name, | |
label=selected_label) | |
return predictions | |
def get_results(self, masks, probs, width, height, from_name, to_name, label): | |
results = [] | |
for mask, prob in zip(masks, probs): | |
# creates a random ID for your label everytime so no chance for errors | |
label_id = str(uuid4())[:4] | |
# converting the mask from the model to RLE format which is usable in Label Studio | |
mask = mask * 255 | |
rle = brush.mask2rle(mask) | |
results.append({ | |
'id': label_id, | |
'from_name': from_name, | |
'to_name': to_name, | |
'original_width': width, | |
'original_height': height, | |
'image_rotation': 0, | |
'value': { | |
'format': 'rle', | |
'rle': rle, | |
'brushlabels': [label], | |
}, | |
'score': prob, | |
'type': 'brushlabels', | |
'readonly': False | |
}) | |
return [{ | |
'result': results, | |
'model_version': PREDICTOR.model_name | |
}] | |
if __name__ == '__main__': | |
# test the model | |
model = SamMLBackend() | |
model.use_label_config(''' | |
<View> | |
<Image name="image" value="$image" zoom="true"/> | |
<BrushLabels name="tag" toName="image"> | |
<Label value="Banana" background="#FF0000"/> | |
<Label value="Orange" background="#0d14d3"/> | |
</BrushLabels> | |
<KeyPointLabels name="tag2" toName="image" smart="true" > | |
<Label value="Banana" background="#000000" showInline="true"/> | |
<Label value="Orange" background="#000000" showInline="true"/> | |
</KeyPointLabels> | |
<RectangleLabels name="tag3" toName="image" > | |
<Label value="Banana" background="#000000" showInline="true"/> | |
<Label value="Orange" background="#000000" showInline="true"/> | |
</RectangleLabels> | |
</View> | |
''') | |
results = model.predict( | |
tasks=[{ | |
'data': { | |
'image': 'https://s3.amazonaws.com/htx-pub/datasets/images/125245483_152578129892066_7843809718842085333_n.jpg' | |
}}], | |
context={ | |
'result': [{ | |
'original_width': 1080, | |
'original_height': 1080, | |
'image_rotation': 0, | |
'value': { | |
'x': 49.441786283891545, | |
'y': 59.96810207336522, | |
'width': 0.3189792663476874, | |
'labels': ['Banana'], | |
'keypointlabels': ['Banana'] | |
}, | |
'is_positive': True, | |
'id': 'fBWv1t0S2L', | |
'from_name': 'tag2', | |
'to_name': 'image', | |
'type': 'keypointlabels', | |
'origin': 'manual' | |
}]} | |
) | |
import json | |
results[0]['result'][0]['value']['rle'] = f'...{len(results[0]["result"][0]["value"]["rle"])} integers...' | |
print(json.dumps(results, indent=2)) |