# Spaces: Runtime error
import logging
import os

import gradio as gr
import pandas as pd
from joblib import load
def load_model(path=""):
    """Load the serialized model bundle from *path* with ``joblib.load``.

    Args:
        path: Filesystem path to the ``.joblib`` file. The empty-string
            default never exists on disk, so a real path must be supplied.

    Returns:
        The object deserialized by ``joblib.load``. Elsewhere in this app it
        is indexed with the ``'model'`` and ``'quantile'`` keys.

    Raises:
        FileNotFoundError: If *path* does not exist. Failing here, at startup,
            is clearer than returning ``None`` and letting the first
            prediction crash on ``None['model']``.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model not found: {path!r}")
    # NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
    return load(path)
# Deserialize the trained bundle once, at import time; diamond_price_regressor
# reads it through this module-level name on every call.
model_dict = load_model(path='diamond_price.joblib')
def diamond_price_regressor(caret, depth, table, x, y, z, cut, color, clarity):
    """Estimate a diamond's price from its physical and quality attributes.

    Gradio calls this with component values passed positionally, in the same
    order as the parameters below.

    Args:
        caret: Weight in carats (the name is a misspelling of "carat", kept so
            the signature is unchanged).
        depth: Total depth percentage.
        table: Width of the top relative to the widest point, as a percentage.
        x: Length in mm.
        y: Width in mm.
        z: Height in mm.
        cut: Cut-quality label.
        color: Color-grade label.
        clarity: Clarity-grade label.

    Returns:
        A display string of the form ``'Approx price is $<price>'`` with the
        price rounded to two decimals.
    """
    logger = logging.getLogger(__name__)
    model = model_dict['model']
    # Applied in reverse below to map the model's output back to a price;
    # presumably the target was quantile-transformed at training time.
    target_converter = model_dict['quantile']
    # NOTE(review): column names/order presumably must match the training frame.
    input_frame = pd.DataFrame({
        'carat': [caret],
        'cut': [cut],
        'color': [color],
        'clarity': [clarity],
        'depth': [depth],
        'table': [table],
        'x': [x],
        'y': [y],
        'z': [z]
    })
    # Debug-level logging instead of print(): request payloads no longer
    # flood stdout unless logging is configured to show them.
    logger.debug("Prediction input:\n%s", input_frame)
    pred = model.predict(input_frame)
    # sklearn-style transformers expect a 2-D (n_samples, n_features) array.
    pred = target_converter.inverse_transform(pred.reshape(-1, 1))
    logger.debug("Predicted price: %s", pred)
    return f'Approx price is ${pred[0][0]:.2f}'
# Category labels offered by the UI widgets, each defined exactly once.
# Lists run from best grade to worst (matching the widget hints
# "D is best ... J is worst" and "IF ... is best, I1 is worst"), so the
# options read as a scale instead of an arbitrary order.
# NOTE(review): labels must match the categories the model was trained on.
cut_choices = ['Ideal', 'Premium', 'Very Good', 'Good', 'Fair']
color_choices = ['D', 'E', 'F', 'G', 'H', 'I', 'J']
clarity_choices = ['IF', 'VVS1', 'VVS2', 'VS1', 'VS2', 'SI1', 'SI2', 'I1']
with gr.Blocks() as ui:
    # Page banner.
    gr.HTML(value="""
<h1>💎 Diamond Price Predictor 💎</h1>
<h3>Find out how much that rock is worth!</h3>
""")

    with gr.Row():
        # NOTE(review): no CSS for .diamond-box is visible and gr.Blocks gets no
        # css= argument, so this class may have no visual effect — confirm.
        with gr.Column(elem_classes="diamond-box"):
            gr.Markdown(value="### Physical Characteristics")
            carat = gr.Slider(
                minimum=0, maximum=10, step=.01, value=.7,
                label="Carat", info="1 carat = 0.2 grams",
            )
            # The three linear dimensions share one row.
            with gr.Row():
                x = gr.Slider(
                    minimum=0, maximum=20, step=.01, value=5,
                    label="Length (mm)", info="x",
                )
                y = gr.Slider(
                    minimum=0, maximum=20, step=.01, value=5,
                    label="Width (mm)", info="y",
                )
                z = gr.Slider(
                    minimum=0, maximum=20, step=.01, value=3.5,
                    label="Height (mm)", info="z",
                )
            depth = gr.Slider(
                minimum=0, maximum=100, step=.01, value=61,
                label="Depth %", info="Total depth percentage",
            )
            table = gr.Slider(
                minimum=0, maximum=100, step=.01, value=57,
                label="Table %", info="Width of top relative to widest point",
            )

        with gr.Column(elem_classes="diamond-box"):
            gr.Markdown(value="### Quality Ratings")
            cut = gr.Radio(
                choices=cut_choices, label="Cut Quality", value="Ideal",
                info="Quality of the cut",
            )
            color = gr.Dropdown(
                choices=color_choices, label="Color Grade", value="E",
                info="D is best (colorless), J is worst",
            )
            clarity = gr.Dropdown(
                choices=clarity_choices, label="Clarity Grade", value="SI2",
                info="IF (flawless) is best, I1 is worst",
            )
            result = gr.Textbox(label="Estimated Price", lines=1)
            predict_btn = gr.Button(value="💰 Calculate Price 💰", variant="primary")

    with gr.Row(elem_classes="diamond-box"):
        gr.Markdown(value="""
### How It Works
This tool uses machine learning to predict diamond prices based on the 4 Cs (Carat, Cut, Color, Clarity)
and other physical measurements. Adjust the sliders and options to see how different characteristics
affect the price!
""")

    # One input list, ordered like diamond_price_regressor's parameters, is
    # shared by the button handler and the examples table.
    price_inputs = [carat, depth, table, x, y, z, cut, color, clarity]

    predict_btn.click(
        fn=diamond_price_regressor,
        inputs=price_inputs,
        outputs=result,
    )

    with gr.Accordion("Example Diamonds", open=False):
        sample_rows = [
            [0.7, 61, 57, 5.7, 5.5, 3.5, "Ideal", "E", "SI2"],
            [1.0, 62, 56, 6.4, 6.3, 4.0, "Premium", "G", "VS1"],
            [2.0, 59, 60, 8.0, 7.95, 4.8, "Very Good", "D", "VVS2"],
        ]
        gr.Examples(
            examples=sample_rows,
            inputs=price_inputs,
            outputs=result,
            fn=diamond_price_regressor,
        )
# Start the Gradio app assembled in the `ui` Blocks context above.
# NOTE(review): this runs unconditionally at import time; wrapping it in
# `if __name__ == "__main__":` would let the module be imported (e.g. by tests)
# without starting a server — confirm how the hosting platform runs this file.
ui.launch()