Spaces:
Sleeping
Sleeping
from PIL import Image, ImageDraw, ImageFont | |
from dotenv import load_dotenv | |
import matplotlib.pyplot as plt | |
from io import BytesIO | |
import gradio as gr | |
import numpy as np | |
import requests | |
import base64 | |
import boto3 | |
import uuid | |
import os | |
import io | |
# Pull any variables defined in a local .env file into os.environ before the
# credentials below are read.
load_dotenv()

# AWS credentials from the environment; each is None when its variable is unset.
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')

# Module-wide S3 client (used by upload2aws to archive submitted images).
s3 = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)
def upload2aws(img_array, bucket='predict-packages', prefix='images_webapp_counters'):
    """Archive an image array to S3 as a JPEG stored under a random UUID key.

    Args:
        img_array: array accepted by ``PIL.Image.fromarray``.
        bucket: destination S3 bucket (defaults to the original hard-coded one).
        prefix: key prefix ("folder") inside the bucket.

    Returns:
        None. The object is written to ``{prefix}/{uuid4}.jpg`` with the
        module-level ``s3`` client.
    """
    image = Image.fromarray(img_array)
    # JPEG has no alpha channel or palette: PIL raises when saving e.g. RGBA/P
    # images as JPEG, so normalise anything other than RGB/grayscale first.
    if image.mode not in ('RGB', 'L'):
        image = image.convert('RGB')
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    # Rewind so put_object uploads from the start of the encoded bytes.
    buffer.seek(0)
    key = f'{prefix}/{uuid.uuid4()}.jpg'
    s3.put_object(Bucket=bucket, Key=key, Body=buffer, ContentType='image/jpeg')
    return None
def send2api(input_img, api_url, timeout=120):
    """POST an image to a prediction endpoint and return the decoded JSON reply.

    The image is JPEG-encoded in memory with ``matplotlib.pyplot.imsave`` and
    sent as the multipart field ``image``.

    Args:
        input_img: image array accepted by ``plt.imsave``.
        api_url: URL of the ``/predict`` endpoint.
        timeout: seconds to wait for the server, forwarded to ``requests.post``
            (previously the request could block forever).

    Returns:
        The parsed JSON body, or an empty dict when the server answers
        204 No Content. Previously both the 204 case and any HTTP error ended
        in an ``UnboundLocalError`` on ``response``.

    Raises:
        requests.HTTPError: if the server returns a 4xx/5xx status.
        requests.RequestException: on connection failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    buf = io.BytesIO()
    plt.imsave(buf, input_img, format='jpg')
    files = {'image': buf.getvalue()}
    res = requests.post(api_url, files=files, timeout=timeout)
    try:
        res.raise_for_status()
    except requests.HTTPError as err:
        # Keep the original diagnostic print, but propagate the real error
        # instead of falling through to an unbound variable.
        print(str(err))
        raise
    if res.status_code == 204:
        return {}
    return res.json()
def displaytext_detclasim(c_cnames, c_scinames, coverage):
    """Build the text summary for the species + common-name counter.

    Layout: the coverage line, then counts by scientific name, then counts by
    common name (each section sorted by descending count), then a total.
    The total is the sum of the scientific-name counts only.
    """
    def _ranked(counts):
        # Stable sort, highest count first; ties keep dict insertion order.
        return sorted(counts.items(), key=lambda pair: pair[1], reverse=True)

    sci_ranked = _ranked(c_scinames)
    common_ranked = _ranked(c_cnames)
    grand_total = sum(n for _, n in sci_ranked)

    chunks = [f'coverage = {coverage}' + '\n\n', 'Countings by scientific name:\n']
    chunks.extend(f'{name} = {n}' + '\n' for name, n in sci_ranked)
    chunks.append('\n\n')
    chunks.append('Countings by common name:\n')
    chunks.extend(f'{name} = {n}' + '\n' for name, n in common_ranked)
    chunks.append('\n')
    chunks.append(f'total = {grand_total}' + '\n')
    return ''.join(chunks)
def displaytext_yolocounter(countings, coverage):
    """Build the text summary for the simplified (YOLO) counter.

    Layout: the coverage line and a blank line, one ``name = count`` line per
    class in descending count order, a blank line, then the grand total.
    """
    ranked = sorted(countings.items(), key=lambda pair: pair[1], reverse=True)
    grand_total = sum(n for _, n in ranked)
    pieces = [f'coverage = {coverage}' + '\n\n']
    pieces += [f'{label} = {n}' + '\n' for label, n in ranked]
    pieces.append('\n')
    pieces.append(f'total = {grand_total}' + '\n')
    return ''.join(pieces)
def display_detectionsandcountings_directcounter(img_array, countings, prob_th=0, cth=0):
    """Draw per-class counts onto an image and return it with a text summary.

    Args:
        img_array: array accepted by ``PIL.Image.fromarray``.
        countings: mapping of class name -> count.
        prob_th: not used; kept so callers that pass it keep working.
        cth: only classes whose count is strictly greater than this are drawn,
            listed in the text and added to the total. Previously the text
            listed every class while the total excluded those at or below
            ``cth``, so the text did not add up.

    Returns:
        ``(img, text)``:
        - ``img`` is the annotated PIL image.
        - ``text`` has one ``name = count`` line per kept class (descending
          count), a blank line, and ``total = N``.
    """
    img = Image.fromarray(img_array)
    draw = ImageDraw.Draw(img)
    # PIL reports size as (width, height); the original unpacked it as (h, w),
    # but the scale factor was (and still is) derived from the width.
    width, _height = img.size
    # Offsets presumably tuned for ~4000 px-wide trap photos; scale to this image.
    ratio = width / 4000
    x = int(50 * ratio)
    line_step = int(100 * ratio)

    # Filter before ranking so the drawn list, the text list and the total agree.
    kept = sorted(
        ((label, count) for label, count in countings.items() if count > cth),
        key=lambda pair: pair[1],
        reverse=True,
    )
    total = sum(count for _, count in kept)

    y = int(20 * ratio)
    for label, count in kept:
        draw.text((x, y), "# {} = {}".format(label, count), fill='red')
        y += line_step
    # Extra gap before the total line.
    y += line_step
    draw.text((x, y), "# {} = {}".format('total', total), fill='red')

    text = ''.join(f'{label} = {count}' + '\n' for label, count in kept)
    text += '\n'
    text += f'total = {total}' + '\n'
    return img, text
def testing_countingid(input_img):
    """Run the countingid service on an image and format its output for the UI.

    The input image is first archived to S3 via ``upload2aws``, then posted to
    the countingid test endpoint.

    Returns:
        ``(img, text)``:
        - ``img`` is the PIL image decoded from the response's base64 ``img_out``.
        - ``text`` is the summary built by ``displaytext_detclasim``.
    """
    # NOTE(review): every submitted image is stored in S3 and sent over plain
    # HTTP; confirm this is acceptable for user uploads.
    upload2aws(input_img)
    api_url = 'http://countingid-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    c_cnames = response['countings_cnames']
    c_scinames = response['countings_scinames']
    coverage = response['coverage']
    # The response also carries 'detections'; it is not needed for this view.
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_detclasim(c_cnames, c_scinames, coverage)
    return img, text
def testing_yolocounter(input_img):
    """Run the yolocounter service on an image and format its output for the UI.

    Returns:
        ``(img, text)``:
        - ``img`` is the PIL image decoded from the response's base64 ``img_out``.
        - ``text`` is the summary built by ``displaytext_yolocounter`` from the
          response's ``countings_scinames`` and ``coverage``.
    """
    api_url = 'http://yolocounter-test.us-east-1.elasticbeanstalk.com/predict'
    response = send2api(input_img, api_url)
    countings = response['countings_scinames']
    coverage = response['coverage']
    # The response also carries 'detections'; it is not needed for this view.
    img = Image.open(BytesIO(base64.b64decode(response['img_out'])))
    text = displaytext_yolocounter(countings, coverage)
    return img, text
def testing_directcounter(input_img):
    """Run the direct-counter service and annotate the original image locally.

    Unlike the other endpoints, the returned image is produced here by
    ``display_detectionsandcountings_directcounter`` from the response's
    ``countings_scinames`` rather than decoded from the API response.
    """
    endpoint = 'http://directcounter-test.us-east-1.elasticbeanstalk.com/predict'
    per_class = send2api(input_img, endpoint)['countings_scinames']
    annotated, summary = display_detectionsandcountings_directcounter(
        input_img, per_class, prob_th=0, cth=0
    )
    return annotated, summary
# Gradio UI. One image input and one "Submit" button (both in the first tab)
# drive both visible tabs: a single click sends the image to the countingid
# endpoint (tab 1 outputs) and to the yolocounter endpoint (tab 2 outputs).
with gr.Blocks() as demo:
    gr.Markdown("Submit an image with insects in a trap")
    # Tab 1: scientific- and common-name counts from the countingid service.
    with gr.Tab("Species & Common Name Count"):
        with gr.Row():
            input1 = gr.Image()
            # NOTE(review): Component.style() is a Gradio 3.x API (removed in
            # Gradio 4); confirm the Space pins a compatible gradio version.
            output1 =[gr.Image().style(height=500, width=500), gr.Textbox(lines=20)]
        button1 = gr.Button("Submit")
        button1.click(testing_countingid, input1, output1)
    # Tab 2: simplified scientific-name counts from the yolocounter service.
    with gr.Tab("Simplified Scientific Name Count"):
        with gr.Row():
            #input2 = gr.Image()
            output2 =[gr.Image().style(height=500, width=500), gr.Textbox(lines=20)]
            #button2 = gr.Button("Submit")
        # Reuses the first tab's button and input instead of a separate one.
        button1.click(testing_yolocounter, input1, output2)
    # Disabled feature: the direct-counter tab below lives inside a string
    # literal (a no-op expression statement), so it is never built.
    """ with gr.Tab("Direct insect counter"):
        with gr.Row():
            #input3 = gr.Image()
            output3 =[gr.Image().style(height=500, width=500), gr.Textbox(lines=20)]
            #button3 = gr.Button("Submit")
        button1.click(testing_directcounter, input1, output3)
    """
# Start the web server (runs on import; Hugging Face Spaces executes this file directly).
demo.launch()