import streamlit as st
import cv2
from ultralytics import YOLO
import folium
from streamlit_folium import st_folium
import requests
from PIL import Image
from io import BytesIO
import numpy as np
import torch
from sklearn.utils.extmath import softmax
import open_clip
import os

# Precomputed per-category embeddings and scores used by knn_get_score,
# stored as an .npz archive (see load_knn below)
knnpath = '20241204-ams-no-env-open_clip_ViT-H-14-378-quickgelu.npz'

# OpenCLIP architecture and pretraining tag used to embed query images
clip_model_name = 'ViT-H-14-378-quickgelu'
pretrained_name = 'dfn5b'

# Perceptual dimensions rated by the k-NN scorer
categories = ['walkability', 'bikeability', 'pleasantness', 'greenness', 'safety']

# Set to True to display intermediate array shapes in the app
debug = False

st.set_page_config(
    page_title="Percept",
    layout="wide"
)

MAPILLARY_ACCESS_TOKEN = os.environ.get('MAPILLARY_ACCESS_TOKEN')

if not MAPILLARY_ACCESS_TOKEN:
    st.error("Mapillary access token not found. Please configure it in the Space secrets.")
    st.stop()
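
# The token is read from the environment, so outside of a Hugging Face Space
# it can be exported before launching the app, e.g.:
#   export MAPILLARY_ACCESS_TOKEN=<your token>
#   streamlit run app.py
# ('app.py' is a placeholder for whatever this file is named)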

def detect_and_crop_street(panorama_url, use_yolo=True):
    """
    Detect street-level content in a panoramic image and return a cropped,
    normal-sized image.

    Args:
        panorama_url: URL of the panoramic image
        use_yolo: Whether to use YOLOv8 (True) or simple edge detection (False)

    Returns:
        cropped_image: PIL Image containing the cropped street view
    """
    response = requests.get(panorama_url)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # OpenCV works in BGR order, PIL in RGB
    cv_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    if use_yolo:
        # Load the pretrained YOLOv8 nano model (weights download on first use)
        model = YOLO('yolov8n.pt')

        results = model(cv_img)

        # Keep detections of COCO class 0 ('person') as a proxy for
        # street-level activity in the panorama
        street_boxes = []
        for result in results:
            for box, cls in zip(result.boxes.xyxy, result.boxes.cls):
                if cls == 0:
                    street_boxes.append(box.cpu().numpy())

        if street_boxes:
            # Crop around the largest detection, padded and clamped to the image bounds
            largest_box = max(street_boxes, key=lambda box: (box[2]-box[0])*(box[3]-box[1]))
            x1, y1, x2, y2 = map(int, largest_box)

            padding = 200
            height, width = cv_img.shape[:2]
            x1 = max(0, x1 - padding)
            y1 = max(0, y1 - padding)
            x2 = min(width, x2 + padding)
            y2 = min(height, y2 + padding)

            cropped = cv_img[y1:y2, x1:x2]
        else:
            # No detections: fall back to edge-based cropping
            cropped = edge_based_crop(cv_img)
    else:
        cropped = edge_based_crop(cv_img)

    cropped_pil = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

    # Resize to a standard width, preserving the crop's aspect ratio
    target_width = 1024
    aspect_ratio = cropped.shape[1] / cropped.shape[0]
    target_height = int(target_width / aspect_ratio)
    cropped_pil = cropped_pil.resize((target_width, target_height), Image.Resampling.LANCZOS)

    return cropped_pil
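
# Example usage (the URL is a placeholder) covering both code paths:
#   pano = detect_and_crop_street('https://example.com/pano.jpg')                  # YOLO path
#   pano = detect_and_crop_street('https://example.com/pano.jpg', use_yolo=False)  # Canny fallback
#   pano.save('street_crop.jpg')  # always a 1024px-wide PIL image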

def edge_based_crop(cv_img):
    """
    Use edge detection to find and crop around street areas.
    """
    gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)

    # Blur to suppress noise before edge detection
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)

    # Canny edge detection with hysteresis thresholds of 50 and 150
    edges = cv2.Canny(blurred, 50, 150)

    # Outer contours of the edge map
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if contours:
        # Bound the largest contour, then pad and clamp to the image borders
        largest_contour = max(contours, key=cv2.contourArea)
        x, y, w, h = cv2.boundingRect(largest_contour)

        padding = 200
        height, width = cv_img.shape[:2]
        x = max(0, x - padding)
        y = max(0, y - padding)
        w = min(width - x, w + 2*padding)
        h = min(height - y, h + 2*padding)

        return cv_img[y:y+h, x:x+w]
    else:
        # No contours: fall back to a center crop one third of the image size
        height, width = cv_img.shape[:2]
        center_x = width // 2
        center_y = height // 2
        crop_width = width // 3
        crop_height = height // 3
        return cv_img[center_y-crop_height//2:center_y+crop_height//2,
                      center_x-crop_width//2:center_x+crop_width//2]

def process_panorama(panorama_url):
    """
    Process a panoramic image to get a street-centered crop.
    """
    try:
        cropped_image = detect_and_crop_street(panorama_url)
        return cropped_image
    except Exception as e:
        st.error(f"Error processing panorama: {str(e)}")
        return None

def get_bounding_box(lat, lon):
    """
    Create a bounding box around a point that extends roughly 25 meters in each
    direction at Amsterdam's latitude (52.37°N):
    - 0.000224 degrees latitude ≈ 25 meters N/S
    - 0.000368 degrees longitude ≈ 25 meters E/W
    """
    lat_offset = 0.000224
    lon_offset = 0.000368
    # Order matches Mapillary's bbox parameter: min_lon, min_lat, max_lon, max_lat
    return [
        lon - lon_offset,
        lat - lat_offset,
        lon + lon_offset,
        lat + lat_offset
    ]
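
# Worked example for Amsterdam's center (52.3676, 4.9041):
#   get_bounding_box(52.3676, 4.9041)
#   -> [4.903732, 52.367376, 4.904468, 52.367824]
# i.e. a box roughly 50 m x 50 m centered on the point.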

def get_nearest_image(lat, lon):
    """
    Get the nearest Mapillary image to the given coordinates.
    """
    bbox = get_bounding_box(lat, lon)
    params = {
        'fields': 'id,thumb_1024_url,is_pano',
        'limit': 1,
        'bbox': f'{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]}'
    }

    header = {'Authorization': f'OAuth {MAPILLARY_ACCESS_TOKEN}'}
    try:
        response = requests.get(
            "https://graph.mapillary.com/images",
            params=params,
            headers=header
        )
        response.raise_for_status()
        data = response.json()

        if 'data' in data and len(data['data']) > 0:
            return data['data'][0]
        return None

    except requests.exceptions.RequestException as e:
        st.error(f"Error fetching Mapillary data: {str(e)}")
        return None
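
# With the fields requested above, a hit comes back shaped roughly like:
#   {'id': '123...', 'thumb_1024_url': 'https://...', 'is_pano': True}
# (values here are illustrative, not real IDs or URLs)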

# cache_resource keeps one model instance per process across Streamlit reruns
@st.cache_resource
def load_model():
    """Load the OpenCLIP model and return model, preprocess transform, and tokenizer"""
    model, _, preprocess = open_clip.create_model_and_transforms(
        clip_model_name, pretrained=pretrained_name
    )
    tokenizer = open_clip.get_tokenizer(clip_model_name)
    return model, preprocess, tokenizer

def process_image(image, preprocess):
    """Process an image (PIL image or URL string) and return a batched tensor"""
    if isinstance(image, str):
        # A string argument is treated as a URL to fetch
        response = requests.get(image)
        image = Image.open(BytesIO(response.content))

    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Add a batch dimension: (3, H, W) -> (1, 3, H, W)
    processed_image = preprocess(image).unsqueeze(0)
    return processed_image
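
# Example (hypothetical URL): either form yields a (1, 3, 378, 378) float tensor
# for the ViT-H-14-378 preprocess, ready for model.encode_image:
#   tensor = process_image('https://example.com/photo.jpg', preprocess)
#   tensor = process_image(pil_image, preprocess)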

def knn_get_score(knn, k, cat, vec):
    """Rate a query embedding against the stored embeddings for one category
    using a similarity-weighted average of its k nearest neighbors."""
    allvecs = knn[f'{cat}_vecs']
    if debug: st.write('allvecs.shape', allvecs.shape)
    scores = knn[f'{cat}_scores']
    if debug: st.write('scores.shape', scores.shape)

    # Cosine similarity of the query against every stored vector
    # (both sides are assumed unit-normalized, so a dot product suffices)
    cos_sim_table = vec @ allvecs.T
    if debug: st.write('cos_sim_table.shape', cos_sim_table.shape)

    # Rank neighbors by decreasing similarity
    sortinds = np.flip(np.argsort(cos_sim_table, axis=1), axis=1)
    if debug: st.write('sortinds.shape', sortinds.shape)

    # Scores of the k most similar reference images
    kscores = scores[sortinds][:, :k]
    if debug: st.write('kscores.shape', kscores.shape)

    # Gather the corresponding top-k similarities
    ksims = cos_sim_table[np.expand_dims(np.arange(sortinds.shape[0]), axis=1), sortinds]
    ksims = ksims[:, :k]
    if debug: st.write('ksims.shape', ksims.shape)

    # Sharpen similarities (10**sim) and normalize them into weights
    ksims = softmax(10**ksims)

    # Weighted average of the neighbors' scores
    kweightedscore = np.sum(kscores * ksims)
    return kweightedscore
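
# Toy illustration of the weighting with k=2: similarities [0.9, 0.8] become
# weights softmax([10**0.9, 10**0.8]) = softmax([7.94, 6.31]) ≈ [0.84, 0.16],
# so neighbor scores [4.0, 2.0] blend to about 4.0*0.84 + 2.0*0.16 ≈ 3.7.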

@st.cache_resource
def load_knn():
    return np.load(knnpath)
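
# The archive is expected to hold, for each category, a '<cat>_vecs' array of
# reference embeddings and a matching '<cat>_scores' array of ratings
# (e.g. 'walkability_vecs' and 'walkability_scores'), as read by knn_get_score.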

def main():
    st.title("Percept: Map Explorer")

    try:
        with st.spinner('Loading CLIP model... This may take a moment.'):
            model, preprocess, tokenizer = load_model()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        st.info("Please make sure you have enough memory and the correct dependencies installed.")
        # Without a model the rest of the app cannot run
        st.stop()

    with st.spinner('Loading KNN model... This may take a moment.'):
        knn = load_knn()

    # Center the map on Amsterdam
    amsterdam_coords = [52.3676, 4.9041]
    m = folium.Map(location=amsterdam_coords, zoom_start=13)

    marker_group = folium.FeatureGroup(name="Marker")
    m.add_child(marker_group)

    # Render the map and capture click events
    map_data = st_folium(m, height=400, width=700)

    if map_data['last_clicked']:
        lat = map_data['last_clicked']['lat']
        lng = map_data['last_clicked']['lng']

        # The marker is added after st_folium has rendered, so it shows up
        # on the next Streamlit rerun
        marker_group.add_child(folium.Marker(
            [lat, lng],
            popup=f"Selected Location\n{lat:.4f}, {lng:.4f}",
            icon=folium.Icon(color="red", icon="info-sign")
        ))

        st.write(f"Selected coordinates: {lat:.4f}, {lng:.4f}")

        with st.spinner('Fetching street view image...'):
            image_data = get_nearest_image(lat, lng)

        if image_data:
            try:
                if image_data['is_pano']:
                    st.write('Processing panoramic image')
                    image = process_panorama(image_data['thumb_1024_url'])
                    # Re-encode the crop as JPEG so it can also be downloaded
                    image_bytes = BytesIO()
                    image.save(image_bytes, format='JPEG')
                    image = Image.open(image_bytes)
                    image_bytes = image_bytes.getvalue()
                else:
                    response = requests.get(image_data['thumb_1024_url'])
                    image = Image.open(BytesIO(response.content))
                    image_bytes = response.content
                st.image(image, caption="Street View", width=400, output_format='JPEG')

                st.download_button(
                    label="Download Image",
                    data=image_bytes,
                    file_name=f"streetview_{lat}_{lng}.jpg",
                    mime="image/jpeg"
                )

                with st.spinner('Processing image...'):
                    processed_image = process_image(image, preprocess)
                    processed_image = processed_image.to(device)

                # Embed the image without tracking gradients
                with torch.no_grad():
                    vec = model.encode_image(processed_image)

                # Unit-normalize so dot products equal cosine similarities
                vec /= vec.norm(dim=-1, keepdim=True)
                if debug: st.write(vec.shape)
                # Move back to CPU before converting for NumPy-based scoring
                vec = vec.cpu().numpy()
                k = 40
                for cat in categories:
                    st.write(cat, f'rating = {knn_get_score(knn, k, cat, vec):.1f}')

            except Exception as e:
                st.error(f"Error displaying image: {str(e)}")
        else:
            st.warning("No street view images found at this location. Try a different spot.")


if __name__ == "__main__":
    main()