# Hugging Face Space entry module: Ollama Instance & Model Scanner.
""" | |
Ollama Instance & Model Scanner for Hugging Face Space | |
This application scans for publicly accessible Ollama instances, retrieves model information, | |
and provides a secure interface for browsing discovered models. | |
Security Architecture: | |
- Server-side authorization based on environment variables | |
- Strict input sanitization | |
- Comprehensive error handling | |
- Asynchronous endpoint checking | |
- Efficient dataset management | |
""" | |
import asyncio
import json
import logging
import os
import re
from datetime import datetime
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple, Union

import aiohttp
import gradio as gr
import shodan
from datasets import Dataset, load_dataset
# Configure logging: INFO level, timestamped records, explicit StreamHandler.
# NOTE(review): logging.basicConfig is a no-op when the root logger already
# has handlers (e.g. under some hosts or test runners) -- confirm this module
# is imported early enough in the Space's lifecycle for it to take effect.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Security layer - Authorization functions | |
def authorization_required(func):
    """
    Decorator that enforces server-side authorization for protected functions.

    Authorization is determined by environment variables (see
    verify_admin_authorization), never by client-supplied parameters.

    Args:
        func: The function to protect with an authorization check.

    Returns:
        A wrapped function that performs the authorization check before
        delegating to ``func``. On failure it returns
        ``{"error": "Unauthorized access"}`` when called with
        ``return_error=True``, otherwise ``None``.
    """
    # Fix: apply functools.wraps (imported but previously unused) so the
    # wrapper keeps func's __name__/__doc__ -- important for logging and
    # for frameworks (e.g. Gradio) that introspect callables.
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not verify_admin_authorization():
            logger.warning(f"Unauthorized access attempt to {func.__name__}")
            return {"error": "Unauthorized access"} if kwargs.get("return_error", False) else None
        return func(*args, **kwargs)
    return wrapper
def verify_admin_authorization() -> bool:
    """
    Server-side check that admin credentials are configured.

    Authorization is derived solely from the SHODAN_API_KEY and HF_TOKEN
    environment variables, never from client-supplied data.

    Returns:
        bool: True when both tokens are present and longer than 10
        characters after stripping whitespace.
    """
    try:
        credentials = (os.getenv("SHODAN_API_KEY"), os.getenv("HF_TOKEN"))
        # Both tokens must exist and be plausibly long enough to be real.
        return all(token is not None and len(token.strip()) > 10
                   for token in credentials)
    except Exception as e:
        logger.error(f"Error verifying admin authorization: {str(e)}")
        return False
# Security layer - Input validation | |
def sanitize_input(input_string: str) -> str:
    """
    Sanitize user input to prevent injection attacks.

    Keeps only word characters, whitespace, hyphens and dots, then caps
    the result at 100 characters.

    Args:
        input_string: Raw user-supplied text.

    Returns:
        str: The sanitized, length-limited string ("" for non-strings).
    """
    # Non-string input is rejected outright rather than coerced.
    if not isinstance(input_string, str):
        return ""
    cleaned = re.sub(r'[^\w\s\-\.]', '', input_string)
    # Truncate to bound downstream work (DoS protection).
    return cleaned[:100]
def get_env_variables() -> Dict[str, str]:
    """
    Read all required environment variables.

    SHODAN_QUERY falls back to a sensible default; the other two have no
    default and must be set.

    Returns:
        Dict[str, str]: Mapping of variable name to value.

    Raises:
        ValueError: If any required environment variable is missing or empty.
    """
    required = {
        "SHODAN_API_KEY": os.getenv("SHODAN_API_KEY"),
        "SHODAN_QUERY": os.getenv("SHODAN_QUERY", "product:Ollama port:11434"),
        "HF_TOKEN": os.getenv("HF_TOKEN"),
    }
    missing = [key for key, value in required.items() if not value]
    if missing:
        message = f"Missing required environment variables: {', '.join(missing)}"
        logger.error(message)
        raise ValueError(message)
    return required
# Data access layer
def load_or_create_dataset() -> Dataset:
    """
    Load the dataset from Hugging Face Hub or create it if it doesn't exist.

    Loads the 'train' split of latterworks/llama_checker_results. If the
    repo cannot be found, an empty dataset with the expected schema is
    created and pushed to the Hub (admin credentials required), then
    reloaded for consistency.

    Returns:
        Dataset: Loaded or created dataset ('train' split).

    Raises:
        ValueError: If required environment variables are missing, or a
            caller without admin credentials triggers the creation path.
        Exception: If dataset loading or creation fails for any other reason.
    """
    try:
        # Attempt to get environment variables - this will raise ValueError if missing
        env_vars = get_env_variables()
        logger.info("Attempting to load dataset from Hugging Face Hub")
        # NOTE(review): `use_auth_token` is deprecated in recent `datasets`
        # releases in favor of `token=` -- confirm against the pinned version.
        dataset = load_dataset("latterworks/llama_checker_results", use_auth_token=env_vars["HF_TOKEN"])
        dataset = dataset['train']
        logger.info(f"Successfully loaded dataset with {len(dataset)} entries")
        return dataset
    except ValueError as e:
        # Re-raise environment variable errors unchanged.
        raise
    except FileNotFoundError:
        # NOTE(review): newer `datasets` versions raise their own
        # DatasetNotFoundError (not FileNotFoundError) for a missing repo;
        # if so, this creation path would never run -- verify.
        # Only create dataset if admin authorization is verified
        if not verify_admin_authorization():
            logger.error("Unauthorized attempt to create dataset")
            raise ValueError("Unauthorized: Only admins can create the dataset")
        logger.info("Dataset not found, creating a new one")
        env_vars = get_env_variables()
        # Empty dataset with the schema the rest of the app expects.
        dataset = Dataset.from_dict({
            "ip": [],
            "port": [],
            "country": [],
            "region": [],
            "org": [],
            "models": []
        })
        dataset.push_to_hub("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])
        logger.info("Created and pushed empty dataset to Hugging Face Hub")
        # Reload the dataset to ensure consistency
        dataset = load_dataset("latterworks/llama_checker_results", use_auth_token=env_vars["HF_TOKEN"])['train']
        return dataset
    except Exception as e:
        error_msg = f"Failed to load or create dataset: {str(e)}"
        logger.error(error_msg)
        raise
async def check_single_endpoint(ip: str, port: int, timeout: int = 5) -> Optional[List[Dict[str, Any]]]:
    """
    Query one Ollama instance for its installed models.

    Args:
        ip: IP address of the Ollama instance.
        port: TCP port of the Ollama instance.
        timeout: Per-request timeout in seconds.

    Returns:
        Optional[List[Dict[str, Any]]]: The model list from /api/tags, or
        None when the endpoint is unreachable, times out, or answers with
        an unexpected payload.
    """
    endpoint = f"http://{ip}:{port}/api/tags"
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(endpoint, timeout=timeout) as response:
                # Guard clause: anything but 200 is treated as a failed probe.
                if response.status != 200:
                    logger.warning(f"Received status code {response.status} from {ip}:{port}")
                    return None
                payload = await response.json()
                models = payload.get("models") if isinstance(payload, dict) else None
                if isinstance(models, list):
                    logger.info(f"Successfully retrieved {len(models)} models from {ip}:{port}")
                    return models
                logger.warning(f"Unexpected response format from {ip}:{port}")
    except aiohttp.ClientError as e:
        logger.warning(f"Connection error for {ip}:{port}: {str(e)}")
    except asyncio.TimeoutError:
        logger.warning(f"Connection timeout for {ip}:{port}")
    except Exception as e:
        logger.warning(f"Unexpected error checking {ip}:{port}: {str(e)}")
    return None
async def check_ollama_endpoints(dataset: Dataset, progress: Optional[gr.Progress] = None) -> Dataset:
    """
    Check all Ollama endpoints in the dataset for available models.

    Probes every endpoint concurrently (in batches of 10), writes any model
    list found back onto the matching row, and pushes the refreshed dataset
    to the Hugging Face Hub.

    Args:
        dataset: Dataset containing Ollama endpoints.
        progress: Optional Gradio progress bar.

    Returns:
        Dataset: Updated dataset with model information.
    """
    if progress:
        progress(0, desc="Preparing to check endpoints...")

    # Fix: datasets.Dataset has no .copy(), and the old add_item() approach
    # appended a *duplicate* row for every endpoint that answered, instead of
    # updating the existing one. Snapshot rows as plain dicts, update them in
    # place, and rebuild the dataset once at the end.
    rows = [dict(item) for item in dataset]
    tasks = [check_single_endpoint(row["ip"], row["port"]) for row in rows]
    total_endpoints = len(rows)

    # Execute probes in batches to avoid overwhelming resources.
    batch_size = 10
    for start in range(0, len(tasks), batch_size):
        if progress:
            progress(start / len(tasks), desc=f"Checking endpoints {start+1}-{min(start+batch_size, len(tasks))} of {len(tasks)}...")
        batch_results = await asyncio.gather(*tasks[start:start + batch_size])
        for offset, result in enumerate(batch_results):
            if result:
                # Update the existing row rather than appending a duplicate.
                rows[start + offset]["models"] = result

    if progress:
        progress(1.0, desc="Endpoint checking complete!")

    # Rebuild; keep the original dataset object when there is nothing to probe.
    updated_dataset = Dataset.from_list(rows) if rows else dataset
    logger.info(f"Checked {total_endpoints} endpoints, found models on {sum(1 for row in rows if row['models'])} endpoints")

    # Push updated dataset to Hugging Face Hub
    env_vars = get_env_variables()
    updated_dataset.push_to_hub("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])
    logger.info("Successfully pushed updated dataset to Hugging Face Hub")
    return updated_dataset
def scan_shodan(progress: Optional[gr.Progress] = None) -> str:
    """
    Scan Shodan for Ollama instances and update the dataset.

    Requires admin credentials: get_env_variables() raises if SHODAN_API_KEY
    or HF_TOKEN is missing. New instances are appended to the Hub dataset and
    then probed asynchronously for available models.

    Args:
        progress: Optional Gradio progress bar.

    Returns:
        str: Human-readable status message (also used for error reporting).
    """
    try:
        # Get environment variables
        env_vars = get_env_variables()
        # Load dataset
        dataset = load_or_create_dataset()
        # Initialize Shodan API client
        api = shodan.Shodan(env_vars["SHODAN_API_KEY"])
        query = env_vars["SHODAN_QUERY"]
        if progress:
            progress(0, desc="Starting Shodan search...")

        # Get total results count
        count_result = api.count(query)
        total_results = count_result.get('total', 0)
        if total_results == 0:
            return "No Ollama instances found on Shodan."
        logger.info(f"Found {total_results} potential Ollama instances on Shodan")

        # Fix: precompute known endpoints once for O(1) membership tests,
        # instead of re-scanning the entire dataset for every Shodan result.
        # The set also dedupes repeated (ip, port) pairs within this scan.
        known_endpoints = {(item["ip"], item["port"]) for item in dataset}

        new_instances = []
        results_processed = 0
        for result in api.search_cursor(query):
            results_processed += 1
            if progress:
                progress(results_processed / total_results,
                         desc=f"Processing Shodan result {results_processed}/{total_results}")
            ip = result.get('ip_str')
            port = result.get('port', 11434)
            # Skip if instance already exists in dataset (or seen this scan).
            if (ip, port) in known_endpoints:
                continue
            known_endpoints.add((ip, port))
            # Extract location information
            location = result.get('location', {})
            new_instances.append({
                "ip": ip,
                "port": port,
                "country": location.get('country_name', ''),
                "region": location.get('region_name', ''),
                "org": result.get('org', ''),
                "models": []
            })

        if progress:
            progress(1.0, desc="Shodan search complete!")

        # Fix: datasets.Dataset has no .copy(); add_item() already returns a
        # new Dataset, so build up from the loaded dataset directly.
        updated_dataset = dataset
        for instance in new_instances:
            updated_dataset = updated_dataset.add_item(instance)
        logger.info(f"Added {len(new_instances)} new instances to dataset")

        # Check Ollama endpoints asynchronously (this also pushes to the Hub).
        if new_instances:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                updated_dataset = loop.run_until_complete(check_ollama_endpoints(updated_dataset, progress))
            finally:
                # Fix: close the loop even when the probe run raises.
                loop.close()

        return f"Scan complete! Found {len(new_instances)} new Ollama instances."
    except shodan.APIError as e:
        error_msg = f"Shodan API error: {str(e)}"
        logger.error(error_msg)
        return error_msg
    except Exception as e:
        error_msg = f"Error during Shodan scan: {str(e)}"
        logger.error(error_msg)
        return error_msg
def get_unique_values(dataset: Dataset, field: str) -> List[str]:
    """
    Collect the sorted unique values of a model-detail field.

    Only the detail fields "family", "parameter_size" and
    "quantization_level" are supported; any other field yields [].

    Args:
        dataset: Dataset (iterable of row dicts) to scan.
        field: Detail field name to collect values for.

    Returns:
        List[str]: Sorted unique, truthy values found across all models.
    """
    # Guard clause: unsupported fields produce an empty result.
    if field not in ("family", "parameter_size", "quantization_level"):
        return []
    values = set()
    for entry in dataset:
        for model in entry.get("models") or []:
            details = model.get("details") or {}
            value = details.get(field)
            if value:
                values.add(value)
    return sorted(values)
def search_models(dataset: Dataset, name_search: str = "", family: str = "", parameter_size: str = "") -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Search for models in the dataset based on filters.

    Endpoint coordinates (ip/port) are attached only when server-side admin
    authorization succeeds; clients cannot request them.

    Args:
        dataset: Dataset to search.
        name_search: Case-insensitive substring matched against model names.
        family: Exact model family filter ("" disables).
        parameter_size: Exact parameter size filter ("" disables).

    Returns:
        Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: (summary rows,
        detailed rows) for every matching model.
    """
    # Server-side authorization check -- never derived from client flags.
    is_admin = verify_admin_authorization()

    needle = sanitize_input(name_search).lower()
    family = sanitize_input(family)
    parameter_size = sanitize_input(parameter_size)

    summaries = []
    detailed_rows = []

    for entry in dataset:
        models = entry.get("models", [])
        if not models:
            continue
        # Per-endpoint context shared by every model on this host.
        endpoint = {"ip": entry.get("ip", ""), "port": entry.get("port", 0)}
        country = entry.get("country", "")
        region = entry.get("region", "")
        org = entry.get("org", "")

        for model in models:
            info = model.get("details", {})
            model_family = info.get("family", "")
            model_params = info.get("parameter_size", "")
            raw_size = model.get("size", 0)
            size_gb = round(raw_size / (1024**3), 2) if raw_size else 0

            # Apply filters; empty filter values match everything.
            if needle and needle not in model.get("name", "").lower():
                continue
            if family and family != model_family:
                continue
            if parameter_size and parameter_size != model_params:
                continue

            summary = {
                "name": model.get("name", ""),
                "family": model_family,
                "parameter_size": model_params,
                "quantization_level": info.get("quantization_level", ""),
                "size_gb": size_gb
            }
            # Detailed row extends the summary with provenance metadata.
            detailed = dict(summary)
            detailed.update({
                "digest": model.get("digest", ""),
                "modified_at": model.get("modified_at", ""),
                "country": country,
                "region": region,
                "org": org
            })
            # Endpoint coordinates are admin-only -- server-side check.
            if is_admin:
                summary.update(endpoint)
                detailed.update(endpoint)
            summaries.append(summary)
            detailed_rows.append(detailed)

    return summaries, detailed_rows
def create_ui() -> gr.Blocks:
    """
    Create the Gradio user interface with server-side authorization.

    Builds a "Browse Models" tab for everyone and, only when admin
    credentials are configured, a "Shodan Scan" tab.

    Returns:
        gr.Blocks: Gradio interface
    """
    # Load dataset
    try:
        dataset = load_or_create_dataset()
    except Exception as e:
        # Fallback to empty dataset if loading fails
        logger.error(f"Failed to load dataset: {str(e)}")
        dataset = Dataset.from_dict({
            "ip": [],
            "port": [],
            "country": [],
            "region": [],
            "org": [],
            "models": []
        })
    # Server-side authorization check
    is_admin = verify_admin_authorization()
    # Get unique values for dropdowns
    families = [""] + get_unique_values(dataset, "family")
    parameter_sizes = [""] + get_unique_values(dataset, "parameter_size")
    # Initial search results
    initial_results, initial_details = search_models(dataset)
    with gr.Blocks(title="Ollama Instance & Model Browser") as app:
        gr.Markdown("# Ollama Instance & Model Browser")
        with gr.Tabs() as tabs:
            with gr.Tab("Browse Models"):
                with gr.Row():
                    with gr.Column(scale=1):
                        name_search = gr.Textbox(label="Model Name Search")
                        family_dropdown = gr.Dropdown(
                            choices=families,
                            label="Model Family",
                            value=""
                        )
                        parameter_size_dropdown = gr.Dropdown(
                            choices=parameter_sizes,
                            label="Parameter Size",
                            value=""
                        )
                        search_button = gr.Button("Search Models")
                with gr.Row():
                    model_results = gr.DataFrame(
                        value=initial_results,
                        label="Model Results",
                        interactive=False
                    )
                with gr.Row():
                    model_details = gr.JSON(label="Model Details")

                def search_callback(name, family, parameter_size):
                    # Returns None for the details panel, so every new search
                    # clears any previously selected model's details.
                    results, details = search_models(dataset, name, family, parameter_size)
                    return results, None

                def select_model(evt: gr.SelectData):
                    # NOTE(review): reading `.value` off Gradio components
                    # yields their *initial* values, not the user's current
                    # input, so the details looked up here may not match the
                    # filters that produced the visible table -- confirm, and
                    # consider passing current values via inputs/gr.State.
                    results, details = search_models(dataset, name_search.value,
                                                     family_dropdown.value,
                                                     parameter_size_dropdown.value)
                    if evt.index[0] < len(details):
                        return details[evt.index[0]]
                    return None

                search_button.click(
                    search_callback,
                    inputs=[name_search, family_dropdown, parameter_size_dropdown],
                    outputs=[model_results, model_details]
                )
                model_results.select(
                    select_model,
                    None,
                    model_details
                )
            # Only show Shodan Scan tab for admins - server-side check
            if is_admin:
                with gr.Tab("Shodan Scan"):
                    gr.Markdown("## Scan for Ollama Instances")
                    gr.Markdown("**Note:** This scan will update the dataset with new Ollama instances.")
                    scan_button = gr.Button("Start Scan")
                    scan_output = gr.Textbox(label="Scan Status")
                    # NOTE(review): binding gr.Progress() as a lambda default
                    # is unusual -- Gradio normally injects Progress via a
                    # default parameter on the handler itself; verify the
                    # progress bar actually updates.
                    scan_button.click(
                        lambda progress=gr.Progress(): scan_shodan(progress),
                        outputs=scan_output
                    )

        # Refresh dataset when the app starts
        def refresh_data():
            nonlocal dataset
            try:
                dataset = load_or_create_dataset()
            except Exception as e:
                logger.error(f"Failed to refresh dataset: {str(e)}")
                # Continue with existing dataset
            results, details = search_models(dataset)
            return results

        app.load(
            fn=refresh_data,
            outputs=model_results
        )
    return app
# Main entry point: build the UI and serve it; any startup failure is
# logged as critical rather than propagated as a raw traceback.
if __name__ == "__main__":
    try:
        application = create_ui()
        application.launch()
    except Exception as e:
        logger.critical(f"Failed to start application: {str(e)}")