Cipher29 commited on
Commit
3f605a7
·
verified ·
1 Parent(s): 81334ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -27
app.py CHANGED
@@ -9,6 +9,18 @@ from sklearn.preprocessing import LabelEncoder
9
  import requests
10
  from io import BytesIO
11
  import gdown
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # --- Set page configuration ---
14
  st.set_page_config(
@@ -411,6 +423,39 @@ def style_metric_container(label, value):
411
  </div>
412
  """, unsafe_allow_html=True)
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  def search_dataset(dataset, make, model=None):
415
  """
416
  Search the dataset for the specified make and model. If no model is provided,
@@ -662,17 +707,21 @@ def predict_with_ranges(inputs, model, label_encoders):
662
  'max_price': max_price
663
  }
664
  # --- Main Application ---
665
- def main(model, label_encoders, dataset):
666
- col1, col2 = st.columns([2, 1])
 
 
 
 
 
 
667
 
668
- with col1:
669
- st.markdown("""
670
- <h1 style='text-align: center;'>The Guide 🚗</h1>
671
- <p style='text-align: center; color: #666; font-size: 1.1rem; margin-bottom: 2rem;'>
672
- A cutting-edge data science project leveraging machine learning to detect which car would be best for you.
673
- </p>
674
- """, unsafe_allow_html=True)
675
-
676
  inputs, predict_button = create_prediction_interface()
677
 
678
  if predict_button:
@@ -685,24 +734,75 @@ def main(model, label_encoders, dataset):
685
  - **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
686
  """)
687
 
688
- # Generate and display the graph
689
- fig = create_market_trends_plot_with_model(model, inputs["make"], inputs, label_encoders)
690
- if fig:
691
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
692
  else:
693
- st.warning("No graph generated. Please check your data or selection.")
694
 
695
- with col2:
696
- create_assistant_section(dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697
 
698
  if __name__ == "__main__":
699
- try:
700
- # Load data and model
701
- original_data = load_datasets()
702
- model, label_encoders = load_model_and_encodings()
703
-
704
- # Call the main function
705
- main(model, label_encoders, original_data)
706
- except Exception as e:
707
- st.error(f"Error loading data or models: {str(e)}")
708
- st.stop()
 
9
  import requests
10
  from io import BytesIO
11
  import gdown
12
+ from PIL import Image
13
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
14
+ import torch
15
+ from datetime import datetime
16
+
17
+ # --- Set page configuration ---
18
+ st.set_page_config(
19
+ page_title="Car Analysis Tool",
20
+ page_icon="🚗",
21
+ layout="wide",
22
+ initial_sidebar_state="expanded"
23
+ )
24
 
25
  # --- Set page configuration ---
26
  st.set_page_config(
 
423
  </div>
424
  """, unsafe_allow_html=True)
425
 
426
+ def classify_image(image):
427
+ try:
428
+ model_name = "dima806/car_models_image_detection"
429
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
430
+ model = AutoModelForImageClassification.from_pretrained(model_name)
431
+
432
+ inputs = feature_extractor(images=image, return_tensors="pt")
433
+
434
+ with torch.no_grad():
435
+ outputs = model(**inputs)
436
+
437
+ logits = outputs.logits
438
+ predicted_class_idx = logits.argmax(-1).item()
439
+ predicted_class_label = model.config.id2label[predicted_class_idx]
440
+ score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
441
+
442
+ return [{'label': predicted_class_label, 'score': score}]
443
+ except Exception as e:
444
+ st.error(f"Classification error: {e}")
445
+ return None
446
+
447
+ def get_car_overview(brand, model, year):
448
+ try:
449
+ prompt = f"Provide an overview of the following car:\nYear: {year}\nMake: {brand}\nModel: {model}\n"
450
+ response = openai.ChatCompletion.create(
451
+ model="gpt-3.5-turbo",
452
+ messages=[{"role": "user", "content": prompt}]
453
+ )
454
+ return response.choices[0].message['content']
455
+ except Exception as e:
456
+ st.error(f"Error getting car overview: {str(e)}")
457
+ return None
458
+
459
  def search_dataset(dataset, make, model=None):
460
  """
461
  Search the dataset for the specified make and model. If no model is provided,
 
707
  'max_price': max_price
708
  }
709
  # --- Main Application ---
710
+ def main():
711
+ # Load necessary data and models
712
+ try:
713
+ original_data = load_datasets()
714
+ model, label_encoders = load_model_and_encodings()
715
+ except Exception as e:
716
+ st.error(f"Error loading data or models: {str(e)}")
717
+ st.stop()
718
 
719
+ # Create tabs
720
+ tab1, tab2 = st.tabs(["Price Prediction", "Image Analysis"])
721
+
722
+ with tab1:
723
+ st.title("Car Price Prediction")
724
+ # [Previous prediction interface code]
 
 
725
  inputs, predict_button = create_prediction_interface()
726
 
727
  if predict_button:
 
734
  - **Model Prediction**: ${prediction_results['predicted_price']:,.2f}
735
  """)
736
 
737
+ # Generate and display the graph
738
+ fig = create_market_trends_plot_with_model(model, inputs["make"], inputs, label_encoders)
739
+ if fig:
740
+ st.pyplot(fig)
741
+
742
+ with tab2:
743
+ st.title("Car Image Analysis")
744
+
745
+ # File uploader and camera input
746
+ uploaded_file = st.file_uploader("Choose a car image", type=["jpg", "jpeg", "png"])
747
+ camera_image = st.camera_input("Or take a picture of the car")
748
+
749
+ # Process the image
750
+ if uploaded_file is not None:
751
+ image = Image.open(uploaded_file)
752
+ elif camera_image is not None:
753
+ image = Image.open(camera_image)
754
  else:
755
+ image = None
756
 
757
+ if image is not None:
758
+ st.image(image, caption='Uploaded Image', use_container_width=True)
759
+
760
+ # Classify the image
761
+ with st.spinner('Analyzing image...'):
762
+ car_classifications = classify_image(image)
763
+
764
+ if car_classifications:
765
+ top_prediction = car_classifications[0]['label']
766
+ make_name, model_name = top_prediction.split(' ', 1)
767
+ current_year = datetime.now().year
768
+
769
+ # Display results
770
+ col1, col2 = st.columns(2)
771
+ col1.metric("Identified Make", make_name)
772
+ col2.metric("Identified Model", model_name)
773
+
774
+ # Get car overview
775
+ overview = get_car_overview(make_name, model_name, current_year)
776
+ if overview:
777
+ st.subheader("Car Overview")
778
+ st.write(overview)
779
+
780
+ # Use the prediction model with the identified car
781
+ st.subheader("Price Analysis for Identified Car")
782
+ auto_inputs = {
783
+ 'year': current_year,
784
+ 'make': make_name.lower(),
785
+ 'model': model_name.lower(),
786
+ 'condition': 'good', # Default values
787
+ 'fuel': 'gas',
788
+ 'odometer': 0,
789
+ 'title_status': 'clean',
790
+ 'transmission': 'automatic',
791
+ 'drive': 'fwd',
792
+ 'size': 'mid-size',
793
+ 'type': 'sedan',
794
+ 'paint_color': 'white'
795
+ }
796
+
797
+ # Get prediction for the identified car
798
+ prediction_results = predict_with_ranges(auto_inputs, model, label_encoders)
799
+
800
+ st.markdown(f"""
801
+ ### Estimated Price Range
802
+ - **Minimum**: ${prediction_results['min_price']:,.2f}
803
+ - **Maximum**: ${prediction_results['max_price']:,.2f}
804
+ - **Predicted**: ${prediction_results['predicted_price']:,.2f}
805
+ """)
806
 
807
  if __name__ == "__main__":
808
+ main()