|
import streamlit as st
|
|
from PIL import Image
|
|
import torch
|
|
from torchvision import transforms
|
|
import pydeck as pdk
|
|
from geopy.geocoders import Nominatim
|
|
import time
|
|
import requests
|
|
from io import BytesIO
|
|
import reverse_geocoder as rg
|
|
from bs4 import BeautifulSoup
|
|
from urllib.parse import urljoin
|
|
from models.huggingface import Geolocalizer
|
|
import spacy
|
|
from collections import Counter
|
|
from spacy.cli import download
|
|
from typing import Tuple, List, Optional, Union, Dict
|
|
|
|
|
|
def load_spacy_model(model_name: str = "en_core_web_md") -> spacy.Language:
    """
    Return the requested spaCy pipeline, downloading it on first use.

    Args:
        model_name (str): Name of the spaCy model package to load.

    Returns:
        spacy.Language: The loaded pipeline.
    """
    try:
        pipeline = spacy.load(model_name)
    except OSError:
        # The load failed with an I/O error (model package not available):
        # download it once, then retry the load.
        print(f"Model {model_name} not found, downloading...")
        download(model_name)
        pipeline = spacy.load(model_name)
    return pipeline
|
|
|
|
|
|
# Shared spaCy pipeline, created once when the module is imported and used by
# most_frequent_locations() for entity extraction.
nlp = load_spacy_model()

# Size handed to transforms.Resize in transform_image().
IMAGE_SIZE = (224, 224)

# Identifier passed to Geolocalizer.from_pretrained() in load_geoloc_model().
GEOLOC_MODEL_NAME = "osv5m/baseline"
|
|
|
|
|
|
@st.cache_resource(show_spinner=True)
def load_geoloc_model() -> Optional[Geolocalizer]:
    """
    Build the geolocation model and switch it to evaluation mode.

    The function is wrapped in ``st.cache_resource``, so the returned value
    is reused by later calls.

    Returns:
        Optional[Geolocalizer]: The model, or None when loading raised (the
        error is reported through ``st.error``).
    """
    with st.spinner('Loading model...'):
        try:
            geolocalizer = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
            geolocalizer.eval()
        except Exception as e:
            st.error(f"Failed to load the model: {e}")
            return None
        return geolocalizer
|
|
|
|
|
|
def most_frequent_locations(text: str) -> Tuple[str, List[str]]:
    """
    Report the two location entities that occur most often in a text.

    Entities labelled ``LOC`` or ``GPE`` by the module-level spaCy pipeline
    are counted by their exact surface text.

    Args:
        text (str): Text to analyse.

    Returns:
        Tuple[str, List[str]]: A human-readable summary and the list of the
        (at most two) most common location strings. When no location entity
        is found the summary is "No locations found" and the list is empty.
    """
    parsed = nlp(text)

    mentions = []
    for entity in parsed.ents:
        if entity.label_ not in ('LOC', 'GPE'):
            continue
        # Trace every accepted entity on stdout.
        print(f"Entity: {entity.text} | Label: {entity.label_} | Sentence: {entity.sent}")
        mentions.append(entity.text)

    if not mentions:
        return "No locations found", []

    top_two = Counter(mentions).most_common(2)
    summary = ', '.join(f"{place} ({count} occurrences)" for place, count in top_two)
    return f"Most Mentioned Locations: {summary}", [place for place, _ in top_two]
|
|
|
|
|
|
def transform_image(image: Image.Image) -> torch.Tensor:
    """
    Convert a PIL image into a normalized, batched tensor for the model.

    The image is converted to RGB when needed, resized to ``IMAGE_SIZE``,
    turned into a tensor, normalized per channel, and given a leading batch
    dimension.

    Args:
        image (Image.Image): Input image in any PIL mode.

    Returns:
        torch.Tensor: Transformed image tensor with a batch dimension of 1
        added at position 0.
    """
    # Normalize below is configured with three-channel statistics, so
    # grayscale, palette, or RGBA inputs must be converted first.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        # Standard ImageNet channel mean / std values.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)
|
|
|
|
|
|
def check_location_match(location_query: dict, most_common_locations: List[str]) -> bool:
    """
    Tell whether any mentioned location contains the whole predicted place.

    A mention matches only when the predicted ``name``, ``admin1`` and ``cc``
    values are ALL substrings of that mention (case-sensitive).

    NOTE(review): because all three parts are required, a bare mention such
    as a city name on its own does not match unless the other two parts are
    also contained in it - confirm this strictness is intended.

    Args:
        location_query (dict): Predicted location with the keys ``name``,
            ``admin1`` and ``cc``.
        most_common_locations (List[str]): Location strings to test.

    Returns:
        bool: True if at least one mention contains all three parts,
        False otherwise (including for an empty list).
    """
    required_parts = (
        location_query['name'],
        location_query['admin1'],
        location_query['cc'],
    )
    return any(
        all(part in mention for part in required_parts)
        for mention in most_common_locations
    )
|
|
|
|
|
|
def get_city_geojson(location_name: str) -> Optional[dict]:
    """
    Look up a place with Nominatim and return its GeoJSON geometry.

    Args:
        location_name (str): Free-form place name to geocode.

    Returns:
        Optional[dict]: The ``geojson`` entry of the geocoder's raw result,
        or None when nothing was found or the lookup raised (the error is
        reported through ``st.error``).
    """
    geocoder = Nominatim(user_agent="predictGeolocforImage")
    try:
        match = geocoder.geocode(location_name, geometry='geojson')
        if not match:
            return None
        return match.raw['geojson']
    except Exception as e:
        st.error(f"Failed to geocode location: {e}")
        return None
|
|
|
|
|
|
def get_media(url: str) -> Optional[List[Tuple[str, str]]]:
    """
    Fetch media URLs and associated text from a JSON endpoint.

    The response body is read as a JSON sequence of entries. Each entry
    contributes one ``(media_url, full_text)`` pair per item of its optional
    ``media`` list that has a ``media_url`` key.

    Args:
        url (str): URL to fetch media from.

    Returns:
        Optional[List[Tuple[str, str]]]: List of ``(media_url, full_text)``
        tuples, or None if the request fails or the payload does not have
        the expected structure (the error is reported through ``st.error``).
    """
    try:
        # Without a timeout a stalled server would block the app indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        data = response.json()
        return [
            (media['media_url'], entry['full_text'])
            for entry in data
            for media in entry.get('media', [])
            if 'media_url' in media
        ]
    except requests.RequestException as e:
        st.error(f"Failed to fetch media URL: {e}")
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # Invalid JSON, or JSON whose shape is not a list of entry objects
        # carrying 'full_text': report it instead of crashing the page.
        st.error(f"Unexpected response format from media URL: {e}")
    return None
|
|
|
|
|
|
def predict_location(image: Image.Image, model: Geolocalizer) -> Optional[Tuple[List[float], dict, Optional[dict], float]]:
    """
    Predict the location of an image and resolve it to a named place.

    The model output is treated as radians, converted to degrees, reverse
    geocoded to the nearest known place, and that place is then looked up
    for its GeoJSON outline.

    Args:
        image (Image.Image): Input image.
        model (Geolocalizer): Geolocation model.

    Returns:
        Optional[Tuple[List[float], dict, Optional[dict], float]]: Predicted
        coordinates in degrees (used as latitude, longitude), the reverse
        geocoder record, the city GeoJSON (or None), and the elapsed time in
        seconds including both geocoding steps. None if any step raises
        (the error is reported through ``st.error``).
    """
    with st.spinner('Processing image and predicting location...'):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()
        try:
            img_tensor = transform_image(image)
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                gps_radians = model(img_tensor)
            gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
            location_query = rg.search((gps_degrees[0], gps_degrees[1]))[0]
            location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
            city_geojson = get_city_geojson(location_name)
            processing_time = time.perf_counter() - start_time
            return gps_degrees, location_query, city_geojson, processing_time
        except Exception as e:
            st.error(f"Failed to predict the location: {e}")
            return None
|
|
|
|
|
|
def display_map(city_geojson: dict, gps_degrees: List[float]) -> None:
    """
    Render a pydeck map of a city outline centred on the given coordinates.

    Args:
        city_geojson (dict): GeoJSON data drawn as a filled, stroked layer.
        gps_degrees (List[float]): Coordinates; index 0 is used as latitude
            and index 1 as longitude for the initial view.
    """
    camera = pdk.ViewState(
        latitude=gps_degrees[0],
        longitude=gps_degrees[1],
        zoom=8,
        pitch=0,
    )
    outline_layer = pdk.Layer(
        'GeoJsonLayer',
        data=city_geojson,
        get_fill_color=[255, 180, 0, 140],
        pickable=True,
        stroked=True,
        filled=True,
        extruded=False,
        line_width_min_pixels=1,
    )
    deck = pdk.Deck(
        map_style='mapbox://styles/mapbox/light-v9',
        initial_view_state=camera,
        layers=[outline_layer],
    )
    st.pydeck_chart(deck)
|
|
|
|
|
|
def display_image(image_url: str) -> None:
    """
    Download an image and show it in the Streamlit page.

    Args:
        image_url (str): URL of the image.

    Any failure is reported through ``st.error`` instead of being raised.
    """
    try:
        # Without a timeout a stalled server would block the app indefinitely.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image_bytes = BytesIO(response.content)
        st.image(image_bytes, caption=f'Image from URL: {image_url}', use_column_width=True)
    except requests.RequestException as e:
        st.error(f"Failed to fetch image at URL {image_url}: {e}")
    except Exception as e:
        # Anything else, e.g. the payload could not be rendered as an image.
        st.error(f"An error occurred: {e}")
|
|
|
|
|
|
def scrape_webpage(url: str) -> Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
    """
    Scrape a web page for its paragraph text and image URLs.

    Args:
        url (str): URL of the webpage to scrape.

    Returns:
        Union[Tuple[Optional[str], Optional[List[str]]], Tuple[None, None]]:
        The text of all ``<p>`` elements joined by single spaces, and the
        absolute URLs of all ``<img>`` elements with a non-empty ``src``.
        ``(None, None)`` if the request fails (the error is reported through
        ``st.error``).
    """
    with st.spinner('Scraping web page...'):
        try:
            # Without a timeout a stalled server would block the app indefinitely.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, 'html.parser')
            # Resolve relative image paths against the final URL so that
            # redirects do not produce broken links.
            base_url = response.url or url
            # Separate paragraphs with a space so the last word of one is not
            # fused with the first word of the next.
            text = ' '.join(p.text for p in soup.find_all('p'))
            images = [
                urljoin(base_url, img['src'])
                for img in soup.find_all('img')
                if img.get('src')
            ]
            return text, images
        except requests.RequestException as e:
            st.error(f"Failed to fetch and parse the URL: {e}")
            return None, None
|
|
|
|
|
|
def main() -> None:
    """
    Entry point of the Streamlit app: draw the sidebar and route to a page.
    """
    st.title('Welcome to Geolocation Guesstimation Demo 👋')

    page = st.sidebar.selectbox(
        "Choose your action:",
        ("Home", "Images", "Social Media", "Web Pages"),
        index=0
    )

    st.sidebar.success("Select a demo above.")
    st.sidebar.info(
        """
        - Web App URL: <https://yunusserhat-guesstimatelocation.hf.space/>
        """
    )

    st.sidebar.title("Contact")
    st.sidebar.info(
        """
        Yunus Serhat Bıçakçı at [yunusserhat.com](https://yunusserhat.com) | [GitHub](https://github.com/yunusserhat) | [Twitter](https://twitter.com/yunusserhat) | [LinkedIn](https://www.linkedin.com/in/yunusserhat)
        """
    )

    # The home page is plain text; every other choice maps to a page function.
    if page == "Home":
        st.write("Welcome to the Geolocation Predictor. Please select an action from the sidebar dropdown.")
        return

    page_renderers = {
        "Images": upload_images_page,
        "Social Media": social_media_page,
        "Web Pages": web_page_url_page,
    }
    render = page_renderers.get(page)
    if render is not None:
        render()
|
|
|
|
|
|
def upload_images_page() -> None:
    """
    Display the image upload page for geolocation prediction.

    Each uploaded file is shown, run through the geolocation model, and the
    predicted place, map (when an outline is available) and processing time
    are written to the page. A file that cannot be read as an image is
    reported and skipped so the remaining uploads are still processed.
    """
    st.header("Image Upload for Geolocation Prediction")
    uploaded_files = st.file_uploader("Choose images...", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
    if not uploaded_files:
        return

    for file in uploaded_files:
        with st.spinner(f"Processing {file.name}..."):
            try:
                image = Image.open(file).convert('RGB')
            except OSError as e:
                # Pillow signals unreadable or corrupt image data with
                # OSError subclasses; one bad upload must not abort the rest.
                st.error(f"Failed to read image {file.name}: {e}")
                continue

            st.image(image, caption=f'Uploaded Image: {file.name}', use_column_width=True)

            model = load_geoloc_model()
            if not model:
                continue

            result = predict_location(image, model)
            if not result:
                continue

            gps_degrees, location_query, city_geojson, processing_time = result
            st.write(
                f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
            if city_geojson:
                display_map(city_geojson, gps_degrees)
            st.write(f"Processing Time (seconds): {processing_time}")
|
|
|
|
|
|
def social_media_page() -> None:
    """
    Display the social media analysis page.

    The text of the first fetched entry is analysed for location mentions.
    Every media URL is then downloaded, shown, and run through the
    geolocation model; a success message is written when the prediction
    matches one of the most mentioned locations. A download or decoding
    failure for one image is reported and the remaining images are still
    processed.
    """
    st.header("Social Media Analyser")
    social_media_url = st.text_input("Enter a social media URL to analyse:", key='social_media_url_input')
    if not social_media_url:
        return

    media_data = get_media(social_media_url)
    if not media_data:
        return

    # Only the text attached to the first media item is analysed.
    full_text = media_data[0][1]
    st.subheader("Full Text")
    st.write(full_text)
    most_used_location, most_common_locations = most_frequent_locations(full_text)
    st.subheader("Most Frequent Location")
    st.write(most_used_location)

    for idx, (media_url, _) in enumerate(media_data, start=1):
        st.subheader(f"Image {idx}")

        try:
            # Without a timeout a stalled server would block the app indefinitely.
            response = requests.get(media_url, timeout=10)
        except requests.RequestException as e:
            st.error(f"Failed to fetch image at URL {media_url}: {e}")
            continue
        if response.status_code != 200:
            st.error(f"Failed to fetch image at URL {media_url}: HTTP {response.status_code}")
            continue

        try:
            image = Image.open(BytesIO(response.content)).convert('RGB')
        except OSError as e:
            # The payload was not a readable image (e.g. an HTML error page).
            st.error(f"Failed to decode image at URL {media_url}: {e}")
            continue

        st.image(image, caption=f'Image from URL: {media_url}', use_column_width=True)

        model = load_geoloc_model()
        if not model:
            continue

        result = predict_location(image, model)
        if not result:
            continue

        gps_degrees, location_query, city_geojson, processing_time = result
        location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
        st.write(
            f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
        if city_geojson:
            display_map(city_geojson, gps_degrees)
        st.write(f"Processing Time (seconds): {processing_time}")
        if check_location_match(location_query, most_common_locations):
            st.success(
                f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
|
|
|
|
|
|
def web_page_url_page() -> None:
    """
    Display the web page URL analysis page.

    The page is scraped for paragraph text and images. The first 500
    characters of the text and its most mentioned locations are shown, and
    the user can pick one image to run through the geolocation model. A
    success message is written when the prediction matches one of the most
    mentioned locations. Download and decoding failures are reported through
    ``st.error`` instead of being raised.
    """
    st.header("Web Page Analyser")
    web_page_url = st.text_input("Enter a web page URL to scrape:", key='web_page_url_input')
    if not web_page_url:
        return

    text, images = scrape_webpage(web_page_url)

    # Bound up front so image prediction also works for pages without any
    # paragraph text (nothing can match in that case).
    most_common_locations: List[str] = []
    if text:
        st.subheader("Extracted Text First 500 Characters:")
        st.write(text[:500])
        most_used_location, most_common_locations = most_frequent_locations(text)
        st.subheader("Most Frequent Location")
        st.write(most_used_location)

    if not images:
        return

    selected_image_url = st.selectbox("Select an image to predict location:", images)
    if not selected_image_url:
        return

    try:
        # Without a timeout a stalled server would block the app indefinitely.
        response = requests.get(selected_image_url, timeout=10)
    except requests.RequestException as e:
        st.error(f"Failed to fetch image at URL {selected_image_url}: {e}")
        return
    if response.status_code != 200:
        st.error(f"Failed to fetch image at URL {selected_image_url}: HTTP {response.status_code}")
        return

    try:
        image = Image.open(BytesIO(response.content)).convert('RGB')
    except OSError as e:
        # Scraped <img> sources are not always decodable raster images.
        st.error(f"Failed to decode image at URL {selected_image_url}: {e}")
        return

    st.image(image, caption=f'Selected Image from URL: {selected_image_url}', use_column_width=True)

    model = load_geoloc_model()
    if not model:
        return

    result = predict_location(image, model)
    if not result:
        return

    gps_degrees, location_query, city_geojson, processing_time = result
    location_name = f"{location_query['name']}, {location_query['admin1']}, {location_query['cc']}"
    st.write(
        f"City: {location_query['name']}, Region: {location_query['admin1']}, Country: {location_query['cc']}")
    if city_geojson:
        display_map(city_geojson, gps_degrees)
    st.write(f"Processing Time (seconds): {processing_time}")
    if check_location_match(location_query, most_common_locations):
        st.success(
            f"The predicted location {location_name} matches one of the most frequently mentioned locations!")
|
|
|
|
|
|
# Run the app only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|
|