Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,18 @@ from sklearn.preprocessing import LabelEncoder
|
|
9 |
import requests
|
10 |
from io import BytesIO
|
11 |
import gdown
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# --- Set page configuration ---
|
14 |
st.set_page_config(
|
@@ -411,6 +423,39 @@ def style_metric_container(label, value):
|
|
411 |
</div>
|
412 |
""", unsafe_allow_html=True)
|
413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
def search_dataset(dataset, make, model=None):
|
415 |
"""
|
416 |
Search the dataset for the specified make and model. If no model is provided,
|
@@ -662,17 +707,21 @@ def predict_with_ranges(inputs, model, label_encoders):
|
|
662 |
'max_price': max_price
|
663 |
}
|
664 |
# --- Main Application ---
|
665 |
-
def main(
|
666 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
""", unsafe_allow_html=True)
|
675 |
-
|
676 |
inputs, predict_button = create_prediction_interface()
|
677 |
|
678 |
if predict_button:
|
@@ -685,24 +734,75 @@ def main(model, label_encoders, dataset):
|
|
685 |
- **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
|
686 |
""")
|
687 |
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
692 |
else:
|
693 |
-
|
694 |
|
695 |
-
|
696 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
697 |
|
698 |
if __name__ == "__main__":
|
699 |
-
|
700 |
-
# Load data and model
|
701 |
-
original_data = load_datasets()
|
702 |
-
model, label_encoders = load_model_and_encodings()
|
703 |
-
|
704 |
-
# Call the main function
|
705 |
-
main(model, label_encoders, original_data)
|
706 |
-
except Exception as e:
|
707 |
-
st.error(f"Error loading data or models: {str(e)}")
|
708 |
-
st.stop()
|
|
|
9 |
import requests
|
10 |
from io import BytesIO
|
11 |
import gdown
|
12 |
+
from PIL import Image
|
13 |
+
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
|
14 |
+
import torch
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
# --- Set page configuration ---
|
18 |
+
st.set_page_config(
|
19 |
+
page_title="Car Analysis Tool",
|
20 |
+
page_icon="🚗",
|
21 |
+
layout="wide",
|
22 |
+
initial_sidebar_state="expanded"
|
23 |
+
)
|
24 |
|
25 |
# --- Set page configuration ---
|
26 |
st.set_page_config(
|
|
|
423 |
</div>
|
424 |
""", unsafe_allow_html=True)
|
425 |
|
426 |
+
def classify_image(image):
|
427 |
+
try:
|
428 |
+
model_name = "dima806/car_models_image_detection"
|
429 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
430 |
+
model = AutoModelForImageClassification.from_pretrained(model_name)
|
431 |
+
|
432 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
433 |
+
|
434 |
+
with torch.no_grad():
|
435 |
+
outputs = model(**inputs)
|
436 |
+
|
437 |
+
logits = outputs.logits
|
438 |
+
predicted_class_idx = logits.argmax(-1).item()
|
439 |
+
predicted_class_label = model.config.id2label[predicted_class_idx]
|
440 |
+
score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
|
441 |
+
|
442 |
+
return [{'label': predicted_class_label, 'score': score}]
|
443 |
+
except Exception as e:
|
444 |
+
st.error(f"Classification error: {e}")
|
445 |
+
return None
|
446 |
+
|
447 |
+
def get_car_overview(brand, model, year):
|
448 |
+
try:
|
449 |
+
prompt = f"Provide an overview of the following car:\nYear: {year}\nMake: {brand}\nModel: {model}\n"
|
450 |
+
response = openai.ChatCompletion.create(
|
451 |
+
model="gpt-3.5-turbo",
|
452 |
+
messages=[{"role": "user", "content": prompt}]
|
453 |
+
)
|
454 |
+
return response.choices[0].message['content']
|
455 |
+
except Exception as e:
|
456 |
+
st.error(f"Error getting car overview: {str(e)}")
|
457 |
+
return None
|
458 |
+
|
459 |
def search_dataset(dataset, make, model=None):
|
460 |
"""
|
461 |
Search the dataset for the specified make and model. If no model is provided,
|
|
|
707 |
'max_price': max_price
|
708 |
}
|
709 |
# --- Main Application ---
|
710 |
+
def main():
|
711 |
+
# Load necessary data and models
|
712 |
+
try:
|
713 |
+
original_data = load_datasets()
|
714 |
+
model, label_encoders = load_model_and_encodings()
|
715 |
+
except Exception as e:
|
716 |
+
st.error(f"Error loading data or models: {str(e)}")
|
717 |
+
st.stop()
|
718 |
|
719 |
+
# Create tabs
|
720 |
+
tab1, tab2 = st.tabs(["Price Prediction", "Image Analysis"])
|
721 |
+
|
722 |
+
with tab1:
|
723 |
+
st.title("Car Price Prediction")
|
724 |
+
# [Previous prediction interface code]
|
|
|
|
|
725 |
inputs, predict_button = create_prediction_interface()
|
726 |
|
727 |
if predict_button:
|
|
|
734 |
- **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
|
735 |
""")
|
736 |
|
737 |
+
# Generate and display the graph
|
738 |
+
fig = create_market_trends_plot_with_model(model, inputs["make"], inputs, label_encoders)
|
739 |
+
if fig:
|
740 |
+
st.pyplot(fig)
|
741 |
+
|
742 |
+
with tab2:
|
743 |
+
st.title("Car Image Analysis")
|
744 |
+
|
745 |
+
# File uploader and camera input
|
746 |
+
uploaded_file = st.file_uploader("Choose a car image", type=["jpg", "jpeg", "png"])
|
747 |
+
camera_image = st.camera_input("Or take a picture of the car")
|
748 |
+
|
749 |
+
# Process the image
|
750 |
+
if uploaded_file is not None:
|
751 |
+
image = Image.open(uploaded_file)
|
752 |
+
elif camera_image is not None:
|
753 |
+
image = Image.open(camera_image)
|
754 |
else:
|
755 |
+
image = None
|
756 |
|
757 |
+
if image is not None:
|
758 |
+
st.image(image, caption='Uploaded Image', use_container_width=True)
|
759 |
+
|
760 |
+
# Classify the image
|
761 |
+
with st.spinner('Analyzing image...'):
|
762 |
+
car_classifications = classify_image(image)
|
763 |
+
|
764 |
+
if car_classifications:
|
765 |
+
top_prediction = car_classifications[0]['label']
|
766 |
+
make_name, model_name = top_prediction.split(' ', 1)
|
767 |
+
current_year = datetime.now().year
|
768 |
+
|
769 |
+
# Display results
|
770 |
+
col1, col2 = st.columns(2)
|
771 |
+
col1.metric("Identified Make", make_name)
|
772 |
+
col2.metric("Identified Model", model_name)
|
773 |
+
|
774 |
+
# Get car overview
|
775 |
+
overview = get_car_overview(make_name, model_name, current_year)
|
776 |
+
if overview:
|
777 |
+
st.subheader("Car Overview")
|
778 |
+
st.write(overview)
|
779 |
+
|
780 |
+
# Use the prediction model with the identified car
|
781 |
+
st.subheader("Price Analysis for Identified Car")
|
782 |
+
auto_inputs = {
|
783 |
+
'year': current_year,
|
784 |
+
'make': make_name.lower(),
|
785 |
+
'model': model_name.lower(),
|
786 |
+
'condition': 'good', # Default values
|
787 |
+
'fuel': 'gas',
|
788 |
+
'odometer': 0,
|
789 |
+
'title_status': 'clean',
|
790 |
+
'transmission': 'automatic',
|
791 |
+
'drive': 'fwd',
|
792 |
+
'size': 'mid-size',
|
793 |
+
'type': 'sedan',
|
794 |
+
'paint_color': 'white'
|
795 |
+
}
|
796 |
+
|
797 |
+
# Get prediction for the identified car
|
798 |
+
prediction_results = predict_with_ranges(auto_inputs, model, label_encoders)
|
799 |
+
|
800 |
+
st.markdown(f"""
|
801 |
+
### Estimated Price Range
|
802 |
+
- **Minimum**: ${prediction_results['min_price']:,.2f}
|
803 |
+
- **Maximum**: ${prediction_results['max_price']:,.2f}
|
804 |
+
- **Predicted**: ${prediction_results['predicted_price']:,.2f}
|
805 |
+
""")
|
806 |
|
807 |
if __name__ == "__main__":
|
808 |
+
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|