File size: 3,875 Bytes
05f9833
714cc17
091113e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05f9833
 
 
 
 
8dd7af4
9eb5891
 
 
 
0b16722
9eb5891
ca76df5
0b16722
8dd7af4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd5a194
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import torch
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import requests
from PIL import Image


from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.utils.visualization.detection import draw_bbox


# Build the model once at import time; process_and_predict reuses this single
# instance for every request.
# NOTE(review): the first argument of `models.get` is normally an architecture
# name (e.g. a `Models.*` member -- `Models` is imported but never used here),
# whereas "best.pt" looks like the checkpoint filename. Confirm which
# architecture ./best.pt was trained with.
yolo_nas_pose = models.get(
    "best.pt",
    checkpoint_path="./best.pt",
    num_classes=1,
)

# Seconds to wait for a remote image; without a timeout a stalled server would
# hang the request handler indefinitely.
_URL_TIMEOUT_SECONDS = 10


def _draw_poses(canvas, pose_data):
    """Draw boxes, skeleton edges and keypoints from *pose_data* onto *canvas*.

    Both renderings returned by ``process_and_predict`` use these exact drawing
    settings, so they share this helper.
    """
    # Imported locally because PoseVisualization is not imported at file
    # level; importing only when drawing also keeps the no-input path free of
    # this dependency.
    from super_gradients.training.utils.visualization.pose_estimation import (
        PoseVisualization,
    )

    return PoseVisualization.draw_poses(
        image=canvas,
        poses=pose_data.poses,
        boxes=pose_data.bboxes_xyxy,
        scores=pose_data.scores,
        is_crowd=None,
        edge_links=pose_data.edge_links,
        edge_colors=pose_data.edge_colors,
        keypoint_colors=pose_data.keypoint_colors,
        joint_thickness=2,
        box_thickness=2,
        keypoint_radius=5,
    )


def process_and_predict(url=None,
                        image=None,
                        confidence=0.5,
                        iou=0.5):
    """Run the model on one image and return two renderings of its predictions.

    Args:
        url: Optional image URL. A non-blank URL takes precedence over
            ``image``.
        image: Optional image passed straight to ``predict`` when no URL is
            given (presumably the numpy array produced by a ``gr.Image``
            input -- confirm).
        confidence: Confidence threshold, forwarded to ``predict`` as ``conf``.
        iou: IoU threshold, forwarded to ``predict``.

    Returns:
        ``(output_image, skeleton_image)``. The first has the predictions drawn
        over the input image. The second has the same predictions drawn on a
        black canvas of the same shape. Returns ``(None, None)`` when neither
        input is provided, so a two-output interface still gets one value per
        output.

    Raises:
        requests.RequestException: If downloading ``url`` fails, times out, or
            returns an HTTP error status.
    """
    if url is not None and url.strip():
        response = requests.get(url, timeout=_URL_TIMEOUT_SECONDS)
        response.raise_for_status()
        # Normalise palette/greyscale/RGBA files to a 3-channel RGB array.
        image = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
    elif image is None:
        return None, None

    result = yolo_nas_pose.predict(image, conf=confidence, iou=iou)

    # Only one image was passed in, so only the first prediction is needed.
    image_prediction = result._images_prediction_lst[0]
    pose_data = image_prediction.prediction

    output_image = _draw_poses(image_prediction.image, pose_data)
    # Same drawing on a black canvas gives a skeleton-only view.
    skeleton_image = _draw_poses(np.zeros_like(image_prediction.image), pose_data)
    return output_image, skeleton_image

def greet(name):
    """Return a greeting for *name*, e.g. ``greet("Ann")`` -> ``"Hello Ann!!"``."""
    return "".join(("Hello ", name, "!!"))

# Example images from the ValorantTracker repository. The `?raw=true` suffix
# makes GitHub return the image file itself rather than its HTML page.
from pathlib import Path
from urllib.request import urlretrieve

EXAMPLE_BASE_URL = "https://github.com/SamDaaLamb/ValorantTracker/blob/main/"
EXAMPLE_FILES = ["clip2_-1450-_jpg.jpg", "clip2_-539-_jpg.jpg"]

for example_file in EXAMPLE_FILES:
    # Skip files already on disk so restarts do not re-download them.
    if not Path(example_file).exists():
        urlretrieve(f"{EXAMPLE_BASE_URL}{example_file}?raw=true", example_file)

# One row per example, one value per input component below:
# (URL, image, confidence, IoU). The example cache must be deleted manually
# every time new examples are added.
examples = [["", example_file, 0.5, 0.5] for example_file in EXAMPLE_FILES]

# Define app features and run.
title = "Valorant Tracker Demo"
description = (
    "<p style='text-align: center'>Gradio demo for a SuperGradients "
    "pose-estimation model. To use it, paste an image URL or upload an image, "
    "or click one of the examples to load it, then adjust the confidence and "
    "IoU thresholds if needed. The outputs show the predictions drawn over the "
    "image and on a blank canvas. Since this demo is run on CPU only, please "
    "allow additional time for processing. </p>"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/SamDaaLamb/ValorantTracker'>Github Repo</a></p>"
)
# elem_id values must be strings, and CSS id selectors cannot start with a digit.
css = (
    "#input-image {object-fit: contain;} "
    "#output-image {object-fit: contain;} "
    "#skeleton-image {object-fit: contain;}"
)

# Inputs mirror process_and_predict(url, image, confidence, iou); outputs
# mirror its (output_image, skeleton_image) return value.
demo = gr.Interface(
    fn=process_and_predict,
    title=title,
    description=description,
    article=article,
    inputs=[
        gr.Textbox(label="Image URL (optional; overrides the uploaded image)"),
        gr.Image(elem_id="input-image", label="Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                  label="Confidence threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05,
                  label="IoU threshold"),
    ],
    outputs=[
        gr.Image(elem_id="output-image", label="Prediction"),
        gr.Image(elem_id="skeleton-image", label="Skeleton only"),
    ],
    css=css,
    examples=examples,
    cache_examples=True,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()