sygma-damage-annotation / functions.py
ychafiqui's picture
added ability to filter images by damaged parts
c6c388a
raw
history blame
2.26 kB
import boto3
from PIL import Image
import pandas as pd
import streamlit as st
import random
import io
# S3 client authenticated with the credentials stored in Streamlit secrets.
s3_client = boto3.client(
    "s3",
    aws_access_key_id=st.secrets["aws_access_key_id"],
    aws_secret_access_key=st.secrets["aws_secret_access_key"],
    region_name="eu-west-3",
)

# Bucket layout: image objects and the annotation CSV use separate prefixes.
bucket_name = "sygma-global-data-storage"
folder = "car-damage-detection/scrappedImages/"
csv_folder = "car-damage-detection/CSVs/"
s3_df_path = csv_folder + "70k_old_annotations_fixed.csv"

# Load the annotation table once, at import time, straight from S3.
# (A local copy at "CSVs/70k_old_annotations_fixed.csv" was used previously.)
response = s3_client.get_object(Bucket=bucket_name, Key=s3_df_path)
with io.BytesIO(response["Body"].read()) as bio:
    df = pd.read_csv(bio, low_memory=False)

# Keep only the rows flagged as available on S3.  The explicit .copy() makes
# ``df`` an independent frame, so the later ``df.loc[...] = ...`` writes in
# process_image() do not trigger pandas' SettingWithCopyWarning.
# ``== True`` is kept on purpose: the column is not guaranteed to be a clean
# boolean dtype (TODO confirm against the CSV).
# NOTE(review): process_image() writes this *filtered* frame back to
# ``s3_df_path``, so rows with a false ``s3_available`` are dropped from the
# stored CSV on the first save -- confirm that this is intended.
df = df[df["s3_available"] == True].copy()  # noqa: E712
def get_random_image(parts_filter=False, max_attempts=20):
    """Pick a random not-yet-validated image and load it from S3.

    Args:
        parts_filter: Falsy for no filtering; otherwise a list of column
            names of ``df``.  Only rows where every listed column is > 0
            are considered.
        max_attempts: Maximum number of distinct candidates to try loading
            before giving up.

    Returns:
        ``(image, image_name)`` where ``image`` is a PIL image resized to
        1000x800, or ``(None, None)`` when no candidate matches or none of
        the attempted candidates could be loaded.
    """
    import logging

    # One boolean mask instead of two name lists and a set intersection.
    mask = df["validated"] == False  # noqa: E712
    if parts_filter:
        # Keep rows where all selected parts are damaged (> 0).
        mask &= (df[parts_filter] > 0).all(axis=1)

    candidates = df.loc[mask, "img_name"].tolist()
    if not candidates:
        return None, None

    # Try a bounded number of distinct candidates in random order.  Unlike a
    # recursive retry, this keeps ``parts_filter`` in effect and cannot
    # exhaust the call stack when S3 is unreachable.
    attempts = min(max(max_attempts, 0), len(candidates))
    for image_name in random.sample(candidates, attempts):
        try:
            s3_image_path = folder + image_name
            response = s3_client.get_object(Bucket=bucket_name, Key=s3_image_path)
            image = Image.open(io.BytesIO(response["Body"].read())).resize((1000, 800))
        except Exception as exc:
            # Best effort: a missing or corrupt object must not break the
            # app, but KeyboardInterrupt/SystemExit are no longer swallowed.
            logging.getLogger(__name__).warning(
                "Skipping image %s: %s", image_name, exc
            )
            continue
        return image, image_name

    return None, None
def get_img_damages(img_name):
    """Return the stored values of one image as a ``{column: value}`` dict.

    Takes the first row of ``df`` whose ``img_name`` equals *img_name* and
    returns every column from position 6 onwards (presumably the per-part
    damage columns -- confirm against the CSV schema).

    Raises:
        IndexError: if no row matches *img_name*.
    """
    matching_rows = df[df["img_name"] == img_name]
    return matching_rows.iloc[0, 6:].to_dict()
def process_image(img_name, annotator_name, is_car, skip, rotation, damaged_parts):
    """Record an annotation for one image and upload the table to S3.

    Args:
        img_name: Value of the ``img_name`` column identifying the row(s)
            to update.
        annotator_name: Stored in the ``annotator_name`` column.
        is_car: Stored in the ``is_car`` column.
        skip: When truthy, the damage columns are left untouched and the
            row is marked as not validated.
        rotation: Stored in the ``rotation`` column.
        damaged_parts: Mapping of column name to value, written to the
            matching row(s) unless ``skip`` is set.

    Side effects:
        Mutates the module-level ``df`` in place and overwrites the CSV at
        ``s3_df_path`` in ``bucket_name`` with its full contents.
    """
    # Compute the row selector once instead of once per assignment.
    row_mask = df["img_name"] == img_name

    df.loc[row_mask, "annotator_name"] = annotator_name
    df.loc[row_mask, "is_car"] = is_car
    df.loc[row_mask, "rotation"] = rotation

    if not skip and damaged_parts:
        # Plain lists rather than dict views: .loc documents list-like
        # indexers and values, and an empty mapping has nothing to write.
        part_columns = list(damaged_parts.keys())
        part_values = list(damaged_parts.values())
        df.loc[row_mask, part_columns] = part_values

    df.loc[row_mask, "validated"] = not skip

    # Persist the whole table; encode explicitly so the upload body is bytes.
    csv_bytes = df.to_csv(index=False).encode("utf-8")
    s3_client.put_object(Bucket=bucket_name, Key=s3_df_path, Body=csv_bytes)