# Hugging Face Space app by zaid-kamil.
# Last commit: "Update app.py" (1cdee9a, verified).
import gradio as gr
from joblib import load
import os
import pandas as pd
# load the model
def load_model(path=""):
    """Load the model bundle stored at *path* with joblib.

    Args:
        path: Location of the ``.joblib`` file.

    Returns:
        The object stored in the file (this app indexes it with ``'model'``
        and ``'quantile'``), or ``None`` when *path* does not exist.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at a model file shipped with the app, never at
    # user-supplied data.
    if not os.path.exists(path):
        # Include the path so a misconfigured deployment is easy to diagnose.
        print(f"Model not found: {path!r}")
        return None
    return load(path)


model_dict = load_model('diamond_price.joblib')
# prediction function
def diamond_price_regressor(caret, depth, table, x, y, z, cut, color, clarity,
                            model_bundle=None):
    """Estimate a diamond's price from its physical and quality attributes.

    The positional parameters arrive from the Gradio UI in the order of its
    inputs list. ``caret`` is a misspelling of "carat" that is kept so the
    signature stays unchanged; its value is fed to the model under the
    ``'carat'`` column.

    Args:
        caret: Weight in carats.
        depth: Total depth percentage.
        table: Width of the top relative to the widest point, in percent.
        x, y, z: Length, width and height in millimetres.
        cut, color, clarity: Categorical quality grades.
        model_bundle: Optional mapping with a ``'model'`` entry (must provide
            ``predict``) and a ``'quantile'`` entry (must provide
            ``inverse_transform``). Defaults to the module-level
            ``model_dict`` loaded at start-up.

    Returns:
        ``'Approx price is $<amount>'``, or an explanatory message when no
        model is available.
    """
    bundle = model_dict if model_bundle is None else model_bundle
    # load_model() returns None when the .joblib file is missing; answer with a
    # readable message instead of raising TypeError inside the Gradio handler.
    if not bundle:
        return 'Model is not loaded - cannot estimate a price.'
    model = bundle['model']
    # Presumably the transformer fitted on the training target; its
    # inverse_transform maps the model's output back to a price.
    target_converter = bundle['quantile']

    # Column names/order must match what the model was trained on.
    input_frame = pd.DataFrame({
        'carat': [caret],
        'cut': [cut],
        'color': [color],
        'clarity': [clarity],
        'depth': [depth],
        'table': [table],
        'x': [x],
        'y': [y],
        'z': [z],
    })
    pred = model.predict(input_frame)
    # The transformer is given a single-column 2-D array.
    pred = target_converter.inverse_transform(pred.reshape(-1, 1))
    return f'Approx price is ${pred[0][0]:.2f}'
# Category levels offered in the UI, each listed best -> worst so the widgets
# read in quality order, consistent with the hints shown next to them
# ("D is best (colorless), J is worst"; "IF (flawless) is best, I1 is worst").
cut_choices = ['Ideal', 'Premium', 'Very Good', 'Good', 'Fair']
color_choices = ['D', 'E', 'F', 'G', 'H', 'I', 'J']
clarity_choices = ['IF', 'VVS1', 'VVS2', 'VS1', 'VS2', 'SI1', 'SI2', 'I1']
# Build the Gradio interface: input widgets, output box, button wiring and
# example rows, all routed through diamond_price_regressor.
# NOTE(review): no CSS for the "diamond-box" class is visible in this file;
# pass css=... to gr.Blocks if that styling is intended — TODO confirm.
with gr.Blocks() as ui:
    gr.HTML("""
<h1>💎 Diamond Price Predictor 💎</h1>
<h3>Find out how much that rock is worth!</h3>
""")
    with gr.Row():
        with gr.Column(elem_classes="diamond-box"):
            gr.Markdown("### Physical Characteristics")
            carat = gr.Slider(minimum=0, maximum=10, step=.01, value=.7, label="Carat", info="1 carat = 0.2 grams")
            with gr.Row():
                x = gr.Slider(minimum=0, maximum=20, step=.01, value=5, label="Length (mm)", info="x")
                y = gr.Slider(minimum=0, maximum=20, step=.01, value=5, label="Width (mm)", info="y")
                z = gr.Slider(minimum=0, maximum=20, step=.01, value=3.5, label="Height (mm)", info="z")
            depth = gr.Slider(minimum=0, maximum=100, step=.01, value=61, label="Depth %", info="Total depth percentage")
            table = gr.Slider(minimum=0, maximum=100, step=.01, value=57, label="Table %", info="Width of top relative to widest point")
        with gr.Column(elem_classes="diamond-box"):
            gr.Markdown("### Quality Ratings")
            cut = gr.Radio(cut_choices, label="Cut Quality", value="Ideal", info="Quality of the cut")
            color = gr.Dropdown(color_choices, label="Color Grade", value="E", info="D is best (colorless), J is worst")
            clarity = gr.Dropdown(clarity_choices, label="Clarity Grade", value="SI2", info="IF (flawless) is best, I1 is worst")
            result = gr.Textbox(label="Estimated Price", lines=1)
            predict_btn = gr.Button("💰 Calculate Price 💰", variant="primary")
    with gr.Row(elem_classes="diamond-box"):
        # Literal text is left unindented: leading spaces would make Markdown
        # render it as a code block.
        gr.Markdown("""
### How It Works
This tool uses machine learning to predict diamond prices based on the 4 Cs (Carat, Cut, Color, Clarity)
and other physical measurements. Adjust the sliders and options to see how different characteristics
affect the price!
""")

    # One shared list keeps the button and the examples in the exact
    # positional order of diamond_price_regressor's parameters.
    feature_inputs = [carat, depth, table, x, y, z, cut, color, clarity]

    # Set up the prediction when button is clicked
    predict_btn.click(
        fn=diamond_price_regressor,
        inputs=feature_inputs,
        outputs=result,
    )

    # Examples: each row follows the order of feature_inputs.
    with gr.Accordion("Example Diamonds", open=False):
        gr.Examples(
            examples=[
                [0.7, 61, 57, 5.7, 5.5, 3.5, "Ideal", "E", "SI2"],
                [1.0, 62, 56, 6.4, 6.3, 4.0, "Premium", "G", "VS1"],
                [2.0, 59, 60, 8.0, 7.95, 4.8, "Very Good", "D", "VVS2"],
            ],
            inputs=feature_inputs,
            outputs=result,
            fn=diamond_price_regressor,
        )

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch it as a side effect.
if __name__ == "__main__":
    ui.launch()