Spaces:
Sleeping
Sleeping
import gradio as gr | |
from definitions import SIZES_PER_CATEGORY | |
from predict import Predict | |
# Single Predict instance, created once at import time and shared by every
# submit_form call. NOTE(review): Predict() presumably loads the model /
# feature-store handles, so that cost is paid once at startup — confirm in
# predict.py.
predictor = Predict()
def _size_dropdown(node):
    """Build the "Size" dropdown update for one node of SIZES_PER_CATEGORY.

    A list node is a leaf category whose entries are the selectable sizes, so
    the dropdown is enabled with those choices. Any other node is a mapping of
    further subcategories: no size can be chosen yet, so the dropdown is emptied
    and disabled (previously it listed the subcategory names as "sizes").

    In both cases the current selection is cleared (``value=None``). Without
    this, a size picked for an earlier category stays selected and would be
    submitted together with the new category.
    """
    if isinstance(node, list):
        return gr.Dropdown(list(node), label="Size", value=None, interactive=True)
    return gr.Dropdown([], label="Size", value=None, interactive=False)


def update_main_categories(main):
    """Handle a change of the main-category dropdown.

    Returns:
        ``(main, "", "", size_dropdown)``: the new main category, cleared sub
        and sub-sub categories, and the Size dropdown update for ``main``.
    """
    return main, "", "", _size_dropdown(SIZES_PER_CATEGORY[main])


def update_sub_categories(main, sub):
    """Handle a change of the sub-category dropdown.

    Returns:
        ``(main, sub, "", size_dropdown)``: the selected path with the sub-sub
        level cleared, and the Size dropdown update for ``main -> sub``.
    """
    return main, sub, "", _size_dropdown(SIZES_PER_CATEGORY[main][sub])


def update_sub_sub_categories(main, sub, sub_sub):
    """Handle a change of the sub-sub-category dropdown.

    Returns:
        ``(main, sub, sub_sub, size_dropdown)``: the full selected path and the
        Size dropdown update for ``main -> sub -> sub_sub``.
    """
    return (
        main,
        sub,
        sub_sub,
        _size_dropdown(SIZES_PER_CATEGORY[main][sub][sub_sub]),
    )
def submit_form(
    title,
    description,
    condition,
    main_category,
    sub_category,
    sub_sub_category,
    size,
    color,
    hashtags,
    designer_names,
    followerno,
    # user_score,
):
    """Collect the form fields, run the price model and return the price.

    Args:
        title, description, condition, size, color: Raw form values.
        main_category, sub_category, sub_sub_category: Selected category path;
            empty levels are skipped when building ``category_path``.
        hashtags: Space-separated hashtag string built by the form.
        designer_names: Space-separated designer string built by the form.
        followerno: Number of followers from the numeric input.

    Returns:
        The model prediction rounded to the nearest whole number, as an int.
    """
    print(
        f"Title: {title}, Description: {description}, Condition: {condition}, hashtags: {hashtags}, designer: {designer_names}, Main Category: {main_category}, Sub:{sub_category}, sub_sub:{sub_sub_category}, Size: {size}, Color: {color}, Followers: {followerno}"
    )
    # Create dictionary with all the input fields
    input_dict = {
        "title": title,
        "description": description,
        "condition": condition,
        # e.g. "main.sub.sub_sub"; levels that were not chosen ("") are dropped.
        "category_path": ".".join(
            filter(None, [main_category, sub_category, sub_sub_category])
        ),
        "size": size,
        "color": color,
        "hashtags": hashtags,
        "designer_names": designer_names,
        "followerno": followerno,
        # "user_score": user_score,
    }
    prediction = predictor.predict(input_dict)
    # Round to the nearest whole dollar. The previous bare int() truncated
    # (99.9 -> 99). float() first in case predict() returns a numpy scalar or
    # size-1 array — NOTE(review): confirm predict()'s return type.
    return round(float(prediction))
def add_hashtag(hashtags, new_hashtag):
    """Add hashtag(s) to a space-separated hashtag string without duplicates.

    Args:
        hashtags: Current hashtags as one space-separated string (may be empty
            or None).
        new_hashtag: User input. Surrounding whitespace is ignored, and
            whitespace-separated words are added as separate hashtags.

    Returns:
        The updated space-separated string. When ``new_hashtag`` is empty or
        blank, ``hashtags`` is returned unchanged.
    """
    new_tags = new_hashtag.split() if new_hashtag else []
    if not new_tags:
        return hashtags
    current = hashtags.split() if hashtags else []
    for tag in new_tags:
        # Compare whole, stripped tokens so padded (" x") or repeated input
        # cannot sneak a duplicate past the check.
        if tag not in current:
            current.append(tag)
    return " ".join(current)
def remove_hashtag(hashtags, hashtag_to_remove):
    """Remove the first occurrence of a hashtag from a space-separated string.

    Args:
        hashtags: Current hashtags as one space-separated string (may be empty
            or None).
        hashtag_to_remove: Tag to drop; surrounding whitespace is ignored.
            None or blank removes nothing (previously None raised
            AttributeError).

    Returns:
        The remaining hashtags joined by single spaces, or "" when
        ``hashtags`` is empty. Removing a tag that is not present is a no-op.
    """
    if not hashtags:
        return ""
    current = hashtags.split()
    target = hashtag_to_remove.strip() if hashtag_to_remove else ""
    if target in current:
        current.remove(target)
    return " ".join(current)
def add_designer(designers, new_designer):
    """Append a designer name to a space-separated designer string.

    Designers are stored as one space-separated string, so a multi-word name
    such as "Rick Owens" becomes consecutive words. A name counts as already
    present when its words occur consecutively in the current string. For a
    single-word name this is the same whole-word check as before. Previously,
    multi-word names were never detected as present and got duplicated on
    every add.

    Args:
        designers: Current designers (may be empty or None).
        new_designer: Name to add; surrounding and repeated inner whitespace
            are normalized.

    Returns:
        The updated space-separated string. When ``new_designer`` is empty or
        blank, ``designers`` is returned unchanged.
    """
    words = new_designer.split() if new_designer else []
    if not words:
        return designers
    current = designers.split() if designers else []
    width = len(words)
    already_listed = any(
        current[start:start + width] == words
        for start in range(len(current) - width + 1)
    )
    if not already_listed:
        current.extend(words)
    return " ".join(current)
def remove_designer(designers, designer_to_remove):
    """Remove the first occurrence of a designer from a space-separated string.

    Because designers are stored space-separated, a multi-word name is matched
    as a consecutive run of words. Previously the whole name was compared
    against single words, so names like "Rick Owens" could never be removed.

    Args:
        designers: Current designers (may be empty or None).
        designer_to_remove: Name to drop; whitespace is normalized. None or
            blank removes nothing (previously None raised AttributeError).

    Returns:
        The remaining designers joined by single spaces, or "" when
        ``designers`` is empty. Removing a name that is not present is a no-op.
    """
    if not designers:
        return ""
    current = designers.split()
    words = designer_to_remove.split() if designer_to_remove else []
    if words:
        width = len(words)
        for start in range(len(current) - width + 1):
            if current[start:start + width] == words:
                del current[start:start + width]
                break
    return " ".join(current)
with gr.Blocks(theme="argilla/argilla-theme", title="Grailed Price Predictor") as demo:
    gr.HTML(
        """
<h1 style="text-align: center;">Grailed Price Predictor</h1>
<p>Welcome to our Grailed Price Prediction Model! This project, developed for the ID2223 course,
demonstrates how modern machine learning systems can be effectively implemented using feature stores and serverless computing.
To get started, simply fill out the form and click "Submit."</p>
"""
    )

    # Currently selected category path ("" = nothing chosen at that level).
    # show_categories below is re-rendered whenever one of these changes.
    main_category = gr.State("")
    sub_category = gr.State("")
    sub_sub_category = gr.State("")

    with gr.Group():
        title_text_box = gr.Textbox(label="Title")
        description_text_area = gr.TextArea(label="Description")
        condition_dropdown = gr.Dropdown(
            ["is_new", "is_gently_used", "is_used", "is_worn"],
            label="Condition",
            info="The condition of your item.",
            interactive=True,
        )
        with gr.Row():
            with gr.Column():

                # Without this decorator the function was never called, so no
                # category dropdowns were shown and the category path stayed "".
                @gr.render(inputs=[main_category, sub_category, sub_sub_category])
                def show_categories(main, sub, sub_sub):
                    """Render the category dropdowns for the current category path.

                    The main dropdown is always shown. The sub and sub-sub
                    dropdowns appear only while the selected node is a mapping
                    of further subcategories rather than a list of sizes. Each
                    dropdown's change handler writes the new path into the
                    category states (which re-runs this render) and updates
                    the Size dropdown.
                    """
                    # size_text_box is defined after this function in the layout.
                    # gr.render only calls it at page load / on input changes,
                    # by which time the module-level name exists.
                    targets = [main_category, sub_category, sub_sub_category, size_text_box]

                    main_dropdown = gr.Dropdown(
                        list(SIZES_PER_CATEGORY.keys()),
                        label="Main Category",
                        # "" means "nothing selected"; None avoids a value that
                        # is not among the choices.
                        value=main or None,
                        interactive=True,
                    )
                    main_dropdown.change(update_main_categories, main_dropdown, targets)
                    if not main or isinstance(SIZES_PER_CATEGORY[main], list):
                        return

                    sub_dropdown = gr.Dropdown(
                        list(SIZES_PER_CATEGORY[main].keys()),
                        label="Sub Category",
                        value=sub or None,
                        interactive=True,
                    )
                    sub_dropdown.change(
                        update_sub_categories,
                        [main_dropdown, sub_dropdown],
                        targets,
                    )
                    if not sub or isinstance(SIZES_PER_CATEGORY[main][sub], list):
                        return

                    sub_sub_dropdown = gr.Dropdown(
                        list(SIZES_PER_CATEGORY[main][sub].keys()),
                        label="Sub Sub Category",
                        value=sub_sub or None,
                        interactive=True,
                    )
                    sub_sub_dropdown.change(
                        update_sub_sub_categories,
                        [main_dropdown, sub_dropdown, sub_sub_dropdown],
                        targets,
                    )

                # Choices and interactivity are driven by the category handlers.
                size_text_box = gr.Dropdown(
                    label="Size",
                    interactive=False,
                )
            color_text_box = gr.Textbox(label="Color")

    with gr.Row():
        with gr.Column():
            # Hashtags are kept as one space-separated string.
            hashtags_state = gr.State("")
            new_hashtag_input = gr.Textbox(label="Add Hashtag")
            with gr.Row():
                add_hashtag_btn = gr.Button("Add Hashtag")
                clear_hashtags_btn = gr.Button("Clear All Hashtags")
            hashtags_display = gr.Dataframe(
                headers=["Hashtag"],
                interactive=False,
                label="Current Hashtags",
            )

            def update_hashtags_display(hashtags):
                """Turn the space-separated hashtag string into one-column table rows."""
                if not hashtags:
                    return []
                return [[tag] for tag in hashtags.split()]

            # Add -> refresh the table -> clear the input box.
            add_hashtag_btn.click(
                add_hashtag,
                inputs=[hashtags_state, new_hashtag_input],
                outputs=hashtags_state,
            ).then(
                update_hashtags_display,
                inputs=[hashtags_state],
                outputs=hashtags_display,
            ).then(
                lambda: "",
                outputs=new_hashtag_input,
            )
            clear_hashtags_btn.click(
                lambda: "",  # Clear the state
                outputs=hashtags_state,
            ).then(
                lambda: [],  # Clear the display
                outputs=hashtags_display,
            )

        with gr.Column():
            # Designers are kept as one space-separated string.
            designers_state = gr.State("")
            new_designer_input = gr.Textbox(label="Add Designer")
            with gr.Row():
                add_designer_btn = gr.Button("Add Designer")
                clear_designers_btn = gr.Button("Clear All Designers")
            designers_display = gr.Dataframe(
                headers=["Designer"],
                interactive=False,
                label="Current Designers",
            )

            def update_designers_display(designers):
                """Turn the space-separated designer string into one-column table rows."""
                if not designers:
                    return []
                return [[designer] for designer in designers.split()]

            # Add -> refresh the table -> clear the input box.
            add_designer_btn.click(
                add_designer,
                inputs=[designers_state, new_designer_input],
                outputs=designers_state,
            ).then(
                update_designers_display,
                inputs=[designers_state],
                outputs=designers_display,
            ).then(
                lambda: "",
                outputs=new_designer_input,
            )
            clear_designers_btn.click(
                lambda: "",  # Clear the state
                outputs=designers_state,
            ).then(
                lambda: [],  # Clear the display
                outputs=designers_display,
            )

    with gr.Row():
        followernoNumber = gr.Number(label="Number of Followers")
        # userScoreNumber = gr.Number(label="Your User Score")
    submitButton = gr.Button("Submit")
    # Output component shown below the submit button.
    prediction_output = gr.Textbox(label="Predicted Price (in USD)", interactive=False)
    # Input order must match submit_form's parameter order.
    submitButton.click(
        submit_form,
        inputs=[
            title_text_box,
            description_text_area,
            condition_dropdown,
            main_category,
            sub_category,
            sub_sub_category,
            size_text_box,
            color_text_box,
            hashtags_state,
            designers_state,
            followernoNumber,
            # userScoreNumber,
        ],
        outputs=prediction_output,
    )


if __name__ == "__main__":
    demo.launch()