CosmickVisions committed on
Commit
ea4b833
·
verified ·
1 Parent(s): 15f7102

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -61
app.py CHANGED
@@ -20,6 +20,7 @@ import tempfile
20
  import time
21
  import matplotlib.pyplot as plt
22
  from pathlib import Path
 
23
 
24
  # Set page config
25
  st.set_page_config(
@@ -462,7 +463,7 @@ def list_bigquery_resources():
462
  return resources
463
 
464
  def process_video_file(video_file, analysis_types):
465
- """Process an uploaded video file with enhanced Vision AI detection"""
466
  # Create a temporary file to save the uploaded video
467
  with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
468
  temp_file.write(video_file.read())
@@ -491,9 +492,13 @@ def process_video_file(video_file, analysis_types):
491
  if int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) > max_frames:
492
  st.info("⚠️ Video is longer than 10 seconds. Only the first 10 seconds will be processed.")
493
 
494
- # Create video writer
495
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
496
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
 
 
 
 
497
 
498
  # Process every Nth frame to reduce API calls but increase from 10 to 5 for more detail
499
  process_every_n_frames = 5
@@ -556,28 +561,28 @@ def process_video_file(video_file, analysis_types):
556
  for vertex in obj.bounding_poly.normalized_vertices]
557
  box = np.array(box, np.int32).reshape((-1, 1, 2))
558
 
559
- # Draw more detailed box
560
- cv2.polylines(frame, [box], True, (0, 255, 0), 2)
561
 
562
  # Calculate box size for better placement of labels
563
  x_min = min([p[0][0] for p in box])
564
  y_min = min([p[0][1] for p in box])
565
  confidence = int(obj.score * 100)
566
 
567
- # Enhanced label with confidence and border
568
  label_text = f"{obj.name}: {confidence}%"
569
- text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
570
 
571
- # Background rectangle for text visibility
572
  cv2.rectangle(frame,
573
- (int(x_min), int(y_min) - text_size[1] - 10),
574
- (int(x_min) + text_size[0] + 10, int(y_min)),
575
- (0, 0, 0), -1)
576
 
577
- # Draw the label text
578
  cv2.putText(frame, label_text,
579
  (int(x_min) + 5, int(y_min) - 5),
580
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
581
 
582
  if "Face Detection" in analysis_types:
583
  faces = client.face_detection(image=vision_image)
@@ -587,9 +592,9 @@ def process_video_file(video_file, analysis_types):
587
  for face in faces.face_annotations:
588
  vertices = face.bounding_poly.vertices
589
  points = [(vertex.x, vertex.y) for vertex in vertices]
590
- # Draw face box
591
  pts = np.array(points, np.int32).reshape((-1, 1, 2))
592
- cv2.polylines(frame, [pts], True, (0, 0, 255), 2)
593
 
594
  # Enhanced face info
595
  emotions = []
@@ -609,13 +614,13 @@ def process_video_file(video_file, analysis_types):
609
  # Add detailed emotion text
610
  cv2.putText(frame, emotion_text,
611
  (int(x_min), int(y_min) - 10),
612
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
613
 
614
  # Draw enhanced landmarks
615
  for landmark in face.landmarks:
616
  px = int(landmark.position.x)
617
  py = int(landmark.position.y)
618
- cv2.circle(frame, (px, py), 2, (255, 255, 0), -1)
619
 
620
  if "Text" in analysis_types:
621
  text = client.text_detection(image=vision_image)
@@ -631,25 +636,25 @@ def process_video_file(video_file, analysis_types):
631
  if len(words) > 5:
632
  short_text += "..."
633
 
634
- # Add text summary to top of frame
635
  cv2.rectangle(frame, (10, 60), (10 + len(short_text)*10, 90), (0, 0, 0), -1)
636
  cv2.putText(frame, f"Text: {short_text}",
637
- (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
638
 
639
  # Draw text boxes with improved visibility
640
  for text_annot in text.text_annotations[1:]:
641
  box = [(vertex.x, vertex.y) for vertex in text_annot.bounding_poly.vertices]
642
  pts = np.array(box, np.int32).reshape((-1, 1, 2))
643
- cv2.polylines(frame, [pts], True, (255, 0, 0), 1)
644
 
645
  # Add Labels analysis for more detail
646
  if "Labels" in analysis_types:
647
  labels = client.label_detection(image=vision_image, max_results=5)
648
 
649
- # Add labels to the frame
650
  y_pos = 120
651
- cv2.rectangle(frame, (10, y_pos-20), (200, y_pos+20*len(labels.label_annotations)), (0, 0, 0), -1)
652
- cv2.putText(frame, "Scene labels:", (15, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
653
 
654
  # Track stats and show labels
655
  for i, label in enumerate(labels.label_annotations):
@@ -659,15 +664,19 @@ def process_video_file(video_file, analysis_types):
659
  else:
660
  detection_stats["labels"][label.description] = 1
661
 
662
- # Display on frame
663
  cv2.putText(frame, f"- {label.description}: {int(label.score*100)}%",
664
- (15, y_pos + 20*(i+1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
665
 
666
  except Exception as e:
667
  # Show error on frame
668
  cv2.putText(frame, f"API Error: {str(e)[:30]}",
669
  (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
670
 
 
 
 
 
671
  # Write the frame to output video
672
  out.write(frame)
673
 
@@ -730,8 +739,8 @@ def load_bigquery_table(dataset_id, table_id, limit=1000):
730
  }
731
 
732
  def main():
733
- # Header
734
- st.markdown('<div class="main-header">Google Cloud AI Analyzer</div>', unsafe_allow_html=True)
735
 
736
  # Navigation
737
  selected = option_menu(
@@ -748,6 +757,9 @@ def main():
748
  with st.sidebar:
749
  st.markdown("### Analysis Settings")
750
 
 
 
 
751
  # Analysis types selection
752
  st.write("Choose analysis types:")
753
  analysis_types = []
@@ -774,45 +786,166 @@ def main():
774
  st.info("This application analyzes images using Google Cloud Vision AI. Upload an image to get started.")
775
 
776
  # Main content
777
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
778
-
779
- if uploaded_file is not None:
780
- # Convert uploaded file to image
781
- image = Image.open(uploaded_file)
782
 
783
- # Apply quality adjustment if needed
784
- if quality < 100:
785
- img_byte_arr = io.BytesIO()
786
- image.save(img_byte_arr, format='JPEG', quality=quality)
787
- image = Image.open(img_byte_arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
 
789
- # Show original image
790
- st.markdown('<div class="subheader">Original Image</div>', unsafe_allow_html=True)
791
- st.image(image, use_column_width=True)
792
 
793
- # Add analyze button
794
- if st.button("Analyze Image"):
795
- if not analysis_types:
796
- st.warning("Please select at least one analysis type.")
797
- else:
798
- with st.spinner("Analyzing image..."):
799
- # Call analyze function
800
- annotated_img, labels, objects, text = analyze_image(image, analysis_types)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
 
802
- # Display results
803
- display_results(annotated_img, labels, objects, text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
804
 
805
- # Add download button for the annotated image
806
- buf = io.BytesIO()
807
- annotated_img.save(buf, format="PNG")
808
- byte_im = buf.getvalue()
809
 
810
- st.download_button(
811
- label="Download Annotated Image",
812
- data=byte_im,
813
- file_name="annotated_image.png",
814
- mime="image/png"
815
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
 
817
  elif selected == "Video Analysis":
818
  st.markdown('<div class="subheader">Video Analysis</div>', unsafe_allow_html=True)
@@ -1138,6 +1271,176 @@ def main():
1138
  st.dataframe(df[num_cols].describe())
1139
  else:
1140
  st.info("Select an existing dataset and table from the sidebar and click 'Load Selected Table', or upload a CSV file in the 'Upload Data' tab.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1141
 
1142
  elif selected == "About":
1143
  st.markdown("## About This App")
@@ -1161,6 +1464,17 @@ def main():
1161
 
1162
  st.info("Note: Make sure your Google Cloud credentials are properly set up to use this application.")
1163
 
 
 
 
 
 
 
 
 
 
 
 
1164
  if __name__ == "__main__":
1165
  # Use GOOGLE_CREDENTIALS directly - no need for file or GOOGLE_APPLICATION_CREDENTIALS
1166
  try:
 
20
  import time
21
  import matplotlib.pyplot as plt
22
  from pathlib import Path
23
+ import plotly.express as px
24
 
25
  # Set page config
26
  st.set_page_config(
 
463
  return resources
464
 
465
  def process_video_file(video_file, analysis_types):
466
+ """Process an uploaded video file with enhanced Vision AI detection and slow down output for better visibility"""
467
  # Create a temporary file to save the uploaded video
468
  with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
469
  temp_file.write(video_file.read())
 
492
  if int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) > max_frames:
493
  st.info("⚠️ Video is longer than 10 seconds. Only the first 10 seconds will be processed.")
494
 
495
+ # Slow down the output video by reducing the fps (60% of original speed)
496
+ output_fps = fps * 0.6
497
+ st.info(f"Output video will be slowed down to {output_fps:.1f} FPS (60% of original speed) for better visualization.")
498
+
499
+ # Create video writer with higher quality settings
500
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # MP4 codec
501
+ out = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height), isColor=True)
502
 
503
  # Process every Nth frame to reduce API calls but increase from 10 to 5 for more detail
504
  process_every_n_frames = 5
 
561
  for vertex in obj.bounding_poly.normalized_vertices]
562
  box = np.array(box, np.int32).reshape((-1, 1, 2))
563
 
564
+ # Draw more noticeable box with thicker lines
565
+ cv2.polylines(frame, [box], True, (0, 255, 0), 3)
566
 
567
  # Calculate box size for better placement of labels
568
  x_min = min([p[0][0] for p in box])
569
  y_min = min([p[0][1] for p in box])
570
  confidence = int(obj.score * 100)
571
 
572
+ # Enhanced label with confidence and border - larger text for visibility
573
  label_text = f"{obj.name}: {confidence}%"
574
+ text_size = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
575
 
576
+ # Larger background rectangle for text visibility
577
  cv2.rectangle(frame,
578
+ (int(x_min), int(y_min) - text_size[1] - 10),
579
+ (int(x_min) + text_size[0] + 10, int(y_min)),
580
+ (0, 0, 0), -1)
581
 
582
+ # Draw the label text with larger font
583
  cv2.putText(frame, label_text,
584
  (int(x_min) + 5, int(y_min) - 5),
585
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
586
 
587
  if "Face Detection" in analysis_types:
588
  faces = client.face_detection(image=vision_image)
 
592
  for face in faces.face_annotations:
593
  vertices = face.bounding_poly.vertices
594
  points = [(vertex.x, vertex.y) for vertex in vertices]
595
+ # Draw face box with thicker lines
596
  pts = np.array(points, np.int32).reshape((-1, 1, 2))
597
+ cv2.polylines(frame, [pts], True, (0, 0, 255), 3)
598
 
599
  # Enhanced face info
600
  emotions = []
 
614
  # Add detailed emotion text
615
  cv2.putText(frame, emotion_text,
616
  (int(x_min), int(y_min) - 10),
617
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
618
 
619
  # Draw enhanced landmarks
620
  for landmark in face.landmarks:
621
  px = int(landmark.position.x)
622
  py = int(landmark.position.y)
623
+ cv2.circle(frame, (px, py), 3, (255, 255, 0), -1) # Larger circles
624
 
625
  if "Text" in analysis_types:
626
  text = client.text_detection(image=vision_image)
 
636
  if len(words) > 5:
637
  short_text += "..."
638
 
639
+ # Add text summary to top of frame with better visibility
640
  cv2.rectangle(frame, (10, 60), (10 + len(short_text)*10, 90), (0, 0, 0), -1)
641
  cv2.putText(frame, f"Text: {short_text}",
642
+ (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
643
 
644
  # Draw text boxes with improved visibility
645
  for text_annot in text.text_annotations[1:]:
646
  box = [(vertex.x, vertex.y) for vertex in text_annot.bounding_poly.vertices]
647
  pts = np.array(box, np.int32).reshape((-1, 1, 2))
648
+ cv2.polylines(frame, [pts], True, (255, 0, 0), 2) # Thicker lines
649
 
650
  # Add Labels analysis for more detail
651
  if "Labels" in analysis_types:
652
  labels = client.label_detection(image=vision_image, max_results=5)
653
 
654
+ # Add labels to the frame with better visibility
655
  y_pos = 120
656
+ cv2.rectangle(frame, (10, y_pos-20), (250, y_pos+20*len(labels.label_annotations)), (0, 0, 0), -1)
657
+ cv2.putText(frame, "Scene labels:", (15, y_pos), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
658
 
659
  # Track stats and show labels
660
  for i, label in enumerate(labels.label_annotations):
 
664
  else:
665
  detection_stats["labels"][label.description] = 1
666
 
667
+ # Display on frame with larger text
668
  cv2.putText(frame, f"- {label.description}: {int(label.score*100)}%",
669
+ (15, y_pos + 20*(i+1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
670
 
671
  except Exception as e:
672
  # Show error on frame
673
  cv2.putText(frame, f"API Error: {str(e)[:30]}",
674
  (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
675
 
676
+ # Add hint about slowed down speed
677
+ cv2.putText(frame, "Playback: 60% speed for better visualization",
678
+ (width - 400, height - 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 200, 0), 2)
679
+
680
  # Write the frame to output video
681
  out.write(frame)
682
 
 
739
  }
740
 
741
  def main():
742
+ # Header - Updated title
743
+ st.markdown('<div class="main-header">Cosmick Cloud AI Analyzer</div>', unsafe_allow_html=True)
744
 
745
  # Navigation
746
  selected = option_menu(
 
757
  with st.sidebar:
758
  st.markdown("### Analysis Settings")
759
 
760
+ # Add mode selection
761
+ processing_mode = st.radio("Processing Mode", ["Single Image", "Batch Processing (up to 5 images)"])
762
+
763
  # Analysis types selection
764
  st.write("Choose analysis types:")
765
  analysis_types = []
 
786
  st.info("This application analyzes images using Google Cloud Vision AI. Upload an image to get started.")
787
 
788
  # Main content
789
+ if processing_mode == "Single Image":
790
+ st.markdown("## Single Image Analysis")
791
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
792
 
793
+ if uploaded_file is not None:
794
+ # Convert uploaded file to image
795
+ image = Image.open(uploaded_file)
796
+
797
+ # Apply quality adjustment if needed
798
+ if quality < 100:
799
+ img_byte_arr = io.BytesIO()
800
+ image.save(img_byte_arr, format='JPEG', quality=quality)
801
+ image = Image.open(img_byte_arr)
802
+
803
+ # Show original image
804
+ st.markdown('<div class="subheader">Original Image</div>', unsafe_allow_html=True)
805
+ st.image(image, use_column_width=True)
806
+
807
+ # Add analyze button
808
+ if st.button("Analyze Image"):
809
+ if not analysis_types:
810
+ st.warning("Please select at least one analysis type.")
811
+ else:
812
+ with st.spinner("Analyzing image..."):
813
+ # Call analyze function
814
+ annotated_img, labels, objects, text = analyze_image(image, analysis_types)
815
+
816
+ # Display results
817
+ display_results(annotated_img, labels, objects, text)
818
+
819
+ # Add download button for the annotated image
820
+ buf = io.BytesIO()
821
+ annotated_img.save(buf, format="PNG")
822
+ byte_im = buf.getvalue()
823
+
824
+ st.download_button(
825
+ label="Download Annotated Image",
826
+ data=byte_im,
827
+ file_name="annotated_image.png",
828
+ mime="image/png"
829
+ )
830
+
831
+ else: # Batch Processing mode
832
+ st.markdown("## Batch Image Analysis")
833
+ st.info("Upload up to 5 images for batch processing.")
834
 
835
+ uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
 
 
836
 
837
+ if uploaded_files:
838
+ # Limit to 5 images
839
+ if len(uploaded_files) > 5:
840
+ st.warning("Maximum 5 images allowed. Only the first 5 will be processed.")
841
+ uploaded_files = uploaded_files[:5]
842
+
843
+ # Display thumbnails of uploaded images
844
+ st.markdown('<div class="subheader">Uploaded Images</div>', unsafe_allow_html=True)
845
+ cols = st.columns(len(uploaded_files))
846
+ for i, uploaded_file in enumerate(uploaded_files):
847
+ with cols[i]:
848
+ image = Image.open(uploaded_file)
849
+ st.image(image, caption=f"Image {i+1}", use_column_width=True)
850
+
851
+ # Add analyze button for batch processing
852
+ if st.button("Analyze All Images"):
853
+ if not analysis_types:
854
+ st.warning("Please select at least one analysis type.")
855
+ else:
856
+ # Initialize containers for batch summary
857
+ all_labels = {}
858
+ all_objects = {}
859
 
860
+ # Process each image
861
+ for i, uploaded_file in enumerate(uploaded_files):
862
+ with st.spinner(f"Analyzing image {i+1} of {len(uploaded_files)}..."):
863
+ # Convert uploaded file to image
864
+ image = Image.open(uploaded_file)
865
+
866
+ # Apply quality adjustment if needed
867
+ if quality < 100:
868
+ img_byte_arr = io.BytesIO()
869
+ image.save(img_byte_arr, format='JPEG', quality=quality)
870
+ image = Image.open(img_byte_arr)
871
+
872
+ # Analyze image
873
+ annotated_img, labels, objects, text = analyze_image(image, analysis_types)
874
+
875
+ # Update batch summaries
876
+ for label, confidence in labels.items():
877
+ if label in all_labels:
878
+ all_labels[label] = max(all_labels[label], confidence)
879
+ else:
880
+ all_labels[label] = confidence
881
+
882
+ for obj, confidence in objects.items():
883
+ if obj in all_objects:
884
+ all_objects[obj] = max(all_objects[obj], confidence)
885
+ else:
886
+ all_objects[obj] = confidence
887
+
888
+ # Create expander for each image result
889
+ with st.expander(f"Results for Image {i+1}", expanded=i==0):
890
+ # Display results for this image
891
+ display_results(annotated_img, labels, objects, text)
892
+
893
+ # Add download button for each annotated image
894
+ buf = io.BytesIO()
895
+ annotated_img.save(buf, format="PNG")
896
+ byte_im = buf.getvalue()
897
+
898
+ st.download_button(
899
+ label=f"Download Annotated Image {i+1}",
900
+ data=byte_im,
901
+ file_name=f"annotated_image_{i+1}.png",
902
+ mime="image/png"
903
+ )
904
+
905
+ # Display batch summary
906
+ st.markdown('<div class="subheader">Batch Analysis Summary</div>', unsafe_allow_html=True)
907
 
908
+ col1, col2 = st.columns(2)
 
 
 
909
 
910
+ with col1:
911
+ if all_labels:
912
+ st.markdown("#### Common Labels Across Images")
913
+ # Sort by confidence
914
+ sorted_labels = dict(sorted(all_labels.items(), key=lambda x: x[1], reverse=True))
915
+ for label, confidence in sorted_labels.items():
916
+ st.markdown(f'<div class="label-item">{label}: {confidence}%</div>', unsafe_allow_html=True)
917
+
918
+ with col2:
919
+ if all_objects:
920
+ st.markdown("#### Common Objects Across Images")
921
+ # Sort by confidence
922
+ sorted_objects = dict(sorted(all_objects.items(), key=lambda x: x[1], reverse=True))
923
+ for obj, confidence in sorted_objects.items():
924
+ st.markdown(f'<div class="object-item">{obj}: {confidence}%</div>', unsafe_allow_html=True)
925
+
926
+ # Create visualization for batch summary if there are labels or objects
927
+ if all_labels or all_objects:
928
+ st.markdown("#### Visual Summary")
929
+
930
+ # Create label chart
931
+ if all_labels:
932
+ fig_labels = px.bar(
933
+ x=list(all_labels.keys()),
934
+ y=list(all_labels.values()),
935
+ labels={'x': 'Label', 'y': 'Confidence (%)'},
936
+ title='Top Labels Across All Images'
937
+ )
938
+ st.plotly_chart(fig_labels)
939
+
940
+ # Create object chart
941
+ if all_objects:
942
+ fig_objects = px.bar(
943
+ x=list(all_objects.keys()),
944
+ y=list(all_objects.values()),
945
+ labels={'x': 'Object', 'y': 'Confidence (%)'},
946
+ title='Top Objects Across All Images'
947
+ )
948
+ st.plotly_chart(fig_objects)
949
 
950
  elif selected == "Video Analysis":
951
  st.markdown('<div class="subheader">Video Analysis</div>', unsafe_allow_html=True)
 
1271
  st.dataframe(df[num_cols].describe())
1272
  else:
1273
  st.info("Select an existing dataset and table from the sidebar and click 'Load Selected Table', or upload a CSV file in the 'Upload Data' tab.")
1274
+
1275
+ with upload_tab:
1276
+ st.markdown("### Upload Data to BigQuery")
1277
+
1278
+ # File uploader for CSV files
1279
+ uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
1280
+
1281
+ if uploaded_file is not None:
1282
+ # Display file details
1283
+ file_details = {
1284
+ "Filename": uploaded_file.name,
1285
+ "File size": f"{uploaded_file.size / 1024:.2f} KB"
1286
+ }
1287
+
1288
+ # Show file preview
1289
+ try:
1290
+ df_preview = pd.read_csv(uploaded_file)
1291
+ st.write("### File Preview")
1292
+ st.dataframe(df_preview.head(5))
1293
+
1294
+ # Store dataframe in session state for other tabs
1295
+ st.session_state["query_results"] = df_preview
1296
+
1297
+ # Upload button
1298
+ if st.button("Upload to BigQuery"):
1299
+ with st.spinner("Uploading to BigQuery..."):
1300
+ try:
1301
+ # Upload the file
1302
+ append = replace_data == "Append to existing data"
1303
+ result = upload_csv_to_bigquery(uploaded_file, dataset_id, table_id, append=append)
1304
+
1305
+ # Show success message
1306
+ st.success(f"Successfully uploaded to {dataset_id}.{table_id}")
1307
+ st.write(f"Rows: {result['num_rows']}")
1308
+ st.write(f"Size: {result['size_bytes'] / 1024:.2f} KB")
1309
+ st.write(f"Schema: {', '.join(result['schema'])}")
1310
+
1311
+ # Store table info in session state
1312
+ st.session_state["table_info"] = {
1313
+ "dataset_id": dataset_id,
1314
+ "table_id": table_id,
1315
+ "schema": result["schema"]
1316
+ }
1317
+ except Exception as e:
1318
+ st.error(f"Error uploading to BigQuery: {str(e)}")
1319
+ except Exception as e:
1320
+ st.error(f"Error reading CSV file: {str(e)}")
1321
+ else:
1322
+ st.info("Upload a CSV file to load data into BigQuery")
1323
+
1324
+ with query_tab:
1325
+ st.markdown("### Query BigQuery Data")
1326
+
1327
+ if "query_results" in st.session_state and "table_info" in st.session_state:
1328
+ # Display info about the loaded data
1329
+ table_info = st.session_state["table_info"]
1330
+ st.write(f"Working with table: **{table_info['dataset_id']}.{table_info['table_id']}**")
1331
+
1332
+ # Query input
1333
+ default_query = f"SELECT * FROM `{credentials.project_id}.{table_info['dataset_id']}.{table_info['table_id']}` LIMIT 100"
1334
+ query = st.text_area("SQL Query", default_query, height=100)
1335
+
1336
+ # Execute query button
1337
+ if st.button("Run Query"):
1338
+ with st.spinner("Executing query..."):
1339
+ try:
1340
+ # Run the query
1341
+ results = run_bigquery(query)
1342
+
1343
+ # Store results in session state
1344
+ st.session_state["query_results"] = results
1345
+
1346
+ # Display results
1347
+ st.write("### Query Results")
1348
+ st.dataframe(results)
1349
+
1350
+ # Download button for results
1351
+ csv = results.to_csv(index=False)
1352
+ st.download_button(
1353
+ label="Download Results as CSV",
1354
+ data=csv,
1355
+ file_name="query_results.csv",
1356
+ mime="text/csv"
1357
+ )
1358
+ except Exception as e:
1359
+ st.error(f"Error executing query: {str(e)}")
1360
+ else:
1361
+ st.info("Load a table from BigQuery or upload a CSV file first")
1362
+
1363
+ with visualization_tab:
1364
+ st.markdown("### Visualize BigQuery Data")
1365
+
1366
+ if "query_results" in st.session_state and not st.session_state["query_results"].empty:
1367
+ df = st.session_state["query_results"]
1368
+
1369
+ # Chart type selection
1370
+ chart_type = st.selectbox(
1371
+ "Select Chart Type",
1372
+ ["Bar Chart", "Line Chart", "Scatter Plot", "Histogram", "Pie Chart"]
1373
+ )
1374
+
1375
+ # Column selection based on data types
1376
+ numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns.tolist()
1377
+ all_cols = df.columns.tolist()
1378
+
1379
+ if len(numeric_cols) < 1:
1380
+ st.warning("No numeric columns available for visualization")
1381
+ else:
1382
+ if chart_type in ["Bar Chart", "Line Chart", "Scatter Plot"]:
1383
+ col1, col2 = st.columns(2)
1384
+
1385
+ with col1:
1386
+ x_axis = st.selectbox("X-axis", all_cols)
1387
+
1388
+ with col2:
1389
+ y_axis = st.selectbox("Y-axis", numeric_cols)
1390
+
1391
+ # Optional: Grouping/color dimension
1392
+ color_dim = st.selectbox("Color Dimension (Optional)", ["None"] + all_cols)
1393
+
1394
+ # Generate the visualization based on selection
1395
+ if st.button("Generate Visualization"):
1396
+ st.write(f"### {chart_type}: {y_axis} by {x_axis}")
1397
+
1398
+ if chart_type == "Bar Chart":
1399
+ if color_dim != "None":
1400
+ fig = px.bar(df, x=x_axis, y=y_axis, color=color_dim,
1401
+ title=f"{y_axis} by {x_axis}")
1402
+ else:
1403
+ fig = px.bar(df, x=x_axis, y=y_axis, title=f"{y_axis} by {x_axis}")
1404
+ st.plotly_chart(fig)
1405
+
1406
+ elif chart_type == "Line Chart":
1407
+ if color_dim != "None":
1408
+ fig = px.line(df, x=x_axis, y=y_axis, color=color_dim,
1409
+ title=f"{y_axis} by {x_axis}")
1410
+ else:
1411
+ fig = px.line(df, x=x_axis, y=y_axis, title=f"{y_axis} by {x_axis}")
1412
+ st.plotly_chart(fig)
1413
+
1414
+ elif chart_type == "Scatter Plot":
1415
+ if color_dim != "None":
1416
+ fig = px.scatter(df, x=x_axis, y=y_axis, color=color_dim,
1417
+ title=f"{y_axis} vs {x_axis}")
1418
+ else:
1419
+ fig = px.scatter(df, x=x_axis, y=y_axis, title=f"{y_axis} vs {x_axis}")
1420
+ st.plotly_chart(fig)
1421
+
1422
+ elif chart_type == "Histogram":
1423
+ column = st.selectbox("Select Column", numeric_cols)
1424
+ bins = st.slider("Number of Bins", min_value=5, max_value=100, value=20)
1425
+
1426
+ if st.button("Generate Visualization"):
1427
+ st.write(f"### Histogram of {column}")
1428
+ fig = px.histogram(df, x=column, nbins=bins, title=f"Distribution of {column}")
1429
+ st.plotly_chart(fig)
1430
+
1431
+ elif chart_type == "Pie Chart":
1432
+ column = st.selectbox("Category Column", all_cols)
1433
+ value_col = st.selectbox("Value Column", numeric_cols)
1434
+
1435
+ if st.button("Generate Visualization"):
1436
+ # Aggregate the data if needed
1437
+ pie_data = df.groupby(column)[value_col].sum().reset_index()
1438
+ st.write(f"### Pie Chart: {value_col} by {column}")
1439
+ fig = px.pie(pie_data, names=column, values=value_col,
1440
+ title=f"{value_col} by {column}")
1441
+ st.plotly_chart(fig)
1442
+ else:
1443
+ st.info("Load a table from BigQuery or upload a CSV file first")
1444
 
1445
  elif selected == "About":
1446
  st.markdown("## About This App")
 
1464
 
1465
  st.info("Note: Make sure your Google Cloud credentials are properly set up to use this application.")
1466
 
1467
+ # At the end of the main function, add this footer
1468
+ # Add footer with attribution
1469
+ st.markdown("""
1470
+ <div style='position: fixed; bottom: 0; width: 100%; background-color: #f8f9fa;
1471
+ text-align: center; padding: 10px; border-top: 1px solid #e9ecef;'>
1472
+ <p style='color: #6c757d; font-size: 14px;'>
1473
+ Powered by Google Cloud with additional tools
1474
+ </p>
1475
+ </div>
1476
+ """, unsafe_allow_html=True)
1477
+
1478
  if __name__ == "__main__":
1479
  # Use GOOGLE_CREDENTIALS directly - no need for file or GOOGLE_APPLICATION_CREDENTIALS
1480
  try: