"""
Ollama Instance & Model Scanner for Hugging Face Space
This application scans for publicly accessible Ollama instances, retrieves model information,
and provides a secure interface for browsing discovered models.
Security Architecture:
- Server-side authorization based on environment variables
- Strict input sanitization
- Comprehensive error handling
- Asynchronous endpoint checking
- Efficient dataset management
"""
import os
import re
import asyncio
import logging
import gradio as gr
import pandas as pd
import shodan
import aiohttp
from datasets import load_dataset, Dataset
from typing import Dict, List, Optional, Any, Tuple
from functools import wraps
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Security layer - Authorization functions
def authorization_required(func):
"""
Decorator that enforces server-side authorization for protected functions.
Authorization is determined by environment variables, not client parameters.
Args:
func: The function to protect with authorization
Returns:
A wrapped function that performs authorization check
"""
@wraps(func)
def wrapper(*args, **kwargs):
if not verify_admin_authorization():
logger.warning(f"Unauthorized access attempt to {func.__name__}")
return {"error": "Unauthorized access"} if kwargs.get("return_error", False) else None
return func(*args, **kwargs)
return wrapper
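# Usage sketch (illustrative, not part of the app flow): state-mutating handlers below are
# wrapped with this decorator, e.g.
#
#     @authorization_required
#     def rebuild_dataset(return_error=True):
#         ...
#
# On failure the wrapper short-circuits with {"error": "Unauthorized access"} when
# return_error=True, otherwise None. When it wraps an async function (see
# check_ollama_endpoints), the coroutine is simply passed through to the caller.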
def verify_admin_authorization() -> bool:
"""
Perform server-side verification of admin authorization.
Authorization is based on environment variables, not client data.
Returns:
bool: True if valid admin credentials exist
"""
try:
        # Require both a Shodan API key and a Hugging Face token
api_key = os.getenv("SHODAN_API_KEY")
hf_token = os.getenv("HF_TOKEN")
return (api_key is not None and
len(api_key.strip()) > 10 and
hf_token is not None and
len(hf_token.strip()) > 10)
except Exception as e:
logger.error(f"Error verifying admin authorization: {str(e)}")
return False
# Security layer - Input validation
def sanitize_input(input_string: str) -> str:
"""
Sanitize user input to prevent injection attacks.
Args:
input_string: User input string to sanitize
Returns:
str: Sanitized string
"""
if not isinstance(input_string, str):
return ""
    # Remove potentially harmful characters; ':' is allowed so model tags like "llama3:8b" stay searchable
    sanitized = re.sub(r'[^\w\s\-\.:]', '', input_string)
# Limit length to prevent DoS
return sanitized[:100]
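# Illustrative example (not executed): characters outside [\w\s\-\.:] are dropped and the
# result is capped at 100 characters, e.g.
#     sanitize_input("llama3.2 <script>alert(1)</script>")  ->  "llama3.2 scriptalert1script"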
def get_env_variables() -> Dict[str, str]:
"""
Get all required environment variables.
Returns:
Dict[str, str]: Dictionary containing environment variables
Raises:
ValueError: If any required environment variable is missing
"""
env_vars = {
"SHODAN_API_KEY": os.getenv("SHODAN_API_KEY"),
"SHODAN_QUERY": os.getenv("SHODAN_QUERY", "product:Ollama port:11434"),
"HF_TOKEN": os.getenv("HF_TOKEN")
}
missing_vars = [name for name, value in env_vars.items() if not value]
if missing_vars:
error_msg = f"Missing required environment variables: {', '.join(missing_vars)}"
logger.error(error_msg)
raise ValueError(error_msg)
return env_vars
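# Illustrative configuration (placeholder values): in a Hugging Face Space these are normally
# provided as repository secrets rather than set in code, e.g.
#     SHODAN_API_KEY = "xxxxxxxxxxxxxxxxxxxx"
#     SHODAN_QUERY   = "product:Ollama port:11434"   # default used when unset
#     HF_TOKEN       = "hf_xxxxxxxxxxxxxxxxxxxx"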
# Data access layer
def load_or_create_dataset() -> Dataset:
"""
Load the dataset from Hugging Face Hub or create it if it doesn't exist.
Returns:
Dataset: Loaded or created dataset
Raises:
Exception: If dataset loading or creation fails
"""
try:
# Attempt to get environment variables - this will raise ValueError if missing
env_vars = get_env_variables()
logger.info("Attempting to load dataset from Hugging Face Hub")
        # `token` replaces the deprecated `use_auth_token` argument in recent datasets releases
        dataset = load_dataset("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])
        dataset = dataset['train']
logger.info(f"Successfully loaded dataset with {len(dataset)} entries")
return dataset
except ValueError as e:
# Re-raise environment variable errors
raise
except FileNotFoundError:
# Only create dataset if admin authorization is verified
if not verify_admin_authorization():
logger.error("Unauthorized attempt to create dataset")
raise ValueError("Unauthorized: Only admins can create the dataset")
logger.info("Dataset not found, creating a new one")
env_vars = get_env_variables()
dataset = Dataset.from_dict({
"ip": [],
"port": [],
"country": [],
"region": [],
"org": [],
"models": []
})
dataset.push_to_hub("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])
logger.info("Created and pushed empty dataset to Hugging Face Hub")
        # Reload the dataset to ensure consistency
        dataset = load_dataset("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])['train']
return dataset
except Exception as e:
error_msg = f"Failed to load or create dataset: {str(e)}"
logger.error(error_msg)
raise
async def check_single_endpoint(ip: str, port: int, timeout: int = 5) -> Optional[List[Dict[str, Any]]]:
"""
Check a single Ollama endpoint for available models.
Args:
ip: IP address of the Ollama instance
port: Port number of the Ollama instance
timeout: Timeout in seconds for the HTTP request
Returns:
Optional[List[Dict[str, Any]]]: List of model information dictionaries, or None if endpoint check fails
"""
url = f"http://{ip}:{port}/api/tags"
try:
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
            async with session.get(url) as response:
if response.status == 200:
data = await response.json()
if "models" in data and isinstance(data["models"], list):
logger.info(f"Successfully retrieved {len(data['models'])} models from {ip}:{port}")
return data["models"]
else:
logger.warning(f"Unexpected response format from {ip}:{port}")
else:
logger.warning(f"Received status code {response.status} from {ip}:{port}")
except aiohttp.ClientError as e:
logger.warning(f"Connection error for {ip}:{port}: {str(e)}")
except asyncio.TimeoutError:
logger.warning(f"Connection timeout for {ip}:{port}")
except Exception as e:
logger.warning(f"Unexpected error checking {ip}:{port}: {str(e)}")
return None
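# For reference, a typical Ollama /api/tags payload looks roughly like the following
# (illustrative values; only the fields read elsewhere in this module are shown):
#     {"models": [{"name": "llama3:latest", "modified_at": "2024-05-01T12:00:00Z",
#                  "size": 4661224676, "digest": "sha256:...",
#                  "details": {"family": "llama", "parameter_size": "8B",
#                              "quantization_level": "Q4_0"}}]}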
@authorization_required
async def check_ollama_endpoints(dataset: Dataset, progress: Optional[gr.Progress] = None) -> Dataset:
"""
Check all Ollama endpoints in the dataset for available models.
Requires admin authorization.
Args:
dataset: Dataset containing Ollama endpoints
progress: Optional Gradio progress bar
Returns:
Dataset: Updated dataset with model information
"""
if progress:
progress(0, desc="Preparing to check endpoints...")
# Build a list of tasks to execute
total_endpoints = len(dataset)
tasks = []
for i, item in enumerate(dataset):
ip = item["ip"]
port = item["port"]
tasks.append(check_single_endpoint(ip, port))
    # Execute tasks in batches to avoid overwhelming resources
    batch_size = 10
    # HF Dataset objects have no copy() method and rows cannot be updated in place,
    # so collect the rows, update them, and rebuild the dataset afterwards
    updated_rows = [dict(item) for item in dataset]
    for i in range(0, len(tasks), batch_size):
        if progress:
            progress(i / len(tasks), desc=f"Checking endpoints {i+1}-{min(i+batch_size, len(tasks))} of {len(tasks)}...")
        batch_tasks = tasks[i:i+batch_size]
        batch_results = await asyncio.gather(*batch_tasks)
        for j, result in enumerate(batch_results):
            idx = i + j
            if idx < len(updated_rows) and result:
                # Update the existing row rather than appending a duplicate entry
                updated_rows[idx]["models"] = result
    updated_dataset = Dataset.from_list(updated_rows) if updated_rows else dataset
if progress:
progress(1.0, desc="Endpoint checking complete!")
logger.info(f"Checked {total_endpoints} endpoints, found models on {sum(1 for item in updated_dataset if item['models'])} endpoints")
# Push updated dataset to Hugging Face Hub
env_vars = get_env_variables()
updated_dataset.push_to_hub("latterworks/llama_checker_results", token=env_vars["HF_TOKEN"])
logger.info("Successfully pushed updated dataset to Hugging Face Hub")
return updated_dataset
@authorization_required
def scan_shodan(progress: Optional[gr.Progress] = None) -> str:
"""
Scan Shodan for Ollama instances and update the dataset.
Requires admin authorization.
Args:
progress: Optional Gradio progress bar
Returns:
str: Status message
"""
try:
# Get environment variables
env_vars = get_env_variables()
# Load dataset
dataset = load_or_create_dataset()
# Initialize Shodan API client
api = shodan.Shodan(env_vars["SHODAN_API_KEY"])
query = env_vars["SHODAN_QUERY"]
if progress:
progress(0, desc="Starting Shodan search...")
# Get total results count
count_result = api.count(query)
total_results = count_result.get('total', 0)
if total_results == 0:
return "No Ollama instances found on Shodan."
logger.info(f"Found {total_results} potential Ollama instances on Shodan")
        # Search Shodan, skipping endpoints that are already present in the dataset
        existing_endpoints = {(item["ip"], item["port"]) for item in dataset}
        new_instances = []
        results_processed = 0
for result in api.search_cursor(query):
results_processed += 1
if progress:
progress(results_processed / total_results,
desc=f"Processing Shodan result {results_processed}/{total_results}")
ip = result.get('ip_str')
port = result.get('port', 11434)
            # Skip if the instance already exists in the dataset or was seen earlier in this scan
            if (ip, port) in existing_endpoints:
                continue
            existing_endpoints.add((ip, port))
# Extract location information
country = result.get('location', {}).get('country_name', '')
region = result.get('location', {}).get('region_name', '')
org = result.get('org', '')
new_instances.append({
"ip": ip,
"port": port,
"country": country,
"region": region,
"org": org,
"models": []
})
if progress:
progress(1.0, desc="Shodan search complete!")
        # Add new instances to the dataset (Dataset.add_item returns a new Dataset, so no explicit copy is needed)
        updated_dataset = dataset
        for instance in new_instances:
            updated_dataset = updated_dataset.add_item(instance)
logger.info(f"Added {len(new_instances)} new instances to dataset")
# Check Ollama endpoints asynchronously
if new_instances:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
updated_dataset = loop.run_until_complete(check_ollama_endpoints(updated_dataset, progress))
loop.close()
status_message = f"Scan complete! Found {len(new_instances)} new Ollama instances."
return status_message
except shodan.APIError as e:
error_msg = f"Shodan API error: {str(e)}"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Error during Shodan scan: {str(e)}"
logger.error(error_msg)
return error_msg
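# For reference, the Shodan match fields consumed above look roughly like this (illustrative,
# using a documentation IP):
#     {"ip_str": "203.0.113.7", "port": 11434, "org": "ExampleNet",
#      "location": {"country_name": "Germany", "region_name": "Bavaria"}}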
def get_unique_values(dataset: Dataset, field: str) -> List[str]:
"""
Get unique values for a specific field in the dataset.
Args:
dataset: Dataset to extract values from
field: Field name to extract values from
Returns:
List[str]: List of unique values
"""
unique_values = set()
    if field in ("family", "parameter_size", "quantization_level"):
for item in dataset:
models = item.get("models", [])
if not models:
continue
for model in models:
details = model.get("details", {})
if details and field in details:
value = details.get(field)
if value:
unique_values.add(value)
return sorted(list(unique_values))
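# Example (illustrative): get_unique_values(dataset, "parameter_size") might return
# something like ["13B", "7B", "8B"]; create_ui() uses these lists to populate the dropdowns.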
def search_models(dataset: Dataset, name_search: str = "", family: str = "", parameter_size: str = "") -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Search for models in the dataset based on filters.
Authorization is determined server-side.
Args:
dataset: Dataset to search
name_search: Model name search string
family: Model family filter
parameter_size: Parameter size filter
Returns:
Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: Filtered model list and detailed model list
"""
# Server-side authorization check
is_admin = verify_admin_authorization()
name_search = sanitize_input(name_search).lower()
family = sanitize_input(family)
parameter_size = sanitize_input(parameter_size)
filtered_models = []
detailed_models = []
for item in dataset:
models = item.get("models", [])
if not models:
continue
ip = item.get("ip", "")
port = item.get("port", 0)
country = item.get("country", "")
region = item.get("region", "")
org = item.get("org", "")
for model in models:
model_name = model.get("name", "").lower()
details = model.get("details", {})
model_family = details.get("family", "")
model_parameter_size = details.get("parameter_size", "")
model_quantization = details.get("quantization_level", "")
model_size = model.get("size", 0)
model_size_gb = round(model_size / (1024**3), 2) if model_size else 0
# Apply filters
if name_search and name_search not in model_name:
continue
if family and family != model_family:
continue
if parameter_size and parameter_size != model_parameter_size:
continue
# Prepare filtered model entry
filtered_model = {
"name": model.get("name", ""),
"family": model_family,
"parameter_size": model_parameter_size,
"quantization_level": model_quantization,
"size_gb": model_size_gb
}
# Add IP and port information only for admins - server-side check
if is_admin:
filtered_model["ip"] = ip
filtered_model["port"] = port
filtered_models.append(filtered_model)
# Prepare detailed model entry
detailed_model = {
"name": model.get("name", ""),
"family": model_family,
"parameter_size": model_parameter_size,
"quantization_level": model_quantization,
"size_gb": model_size_gb,
"digest": model.get("digest", ""),
"modified_at": model.get("modified_at", ""),
"country": country,
"region": region,
"org": org
}
# Add IP and port information only for admins - server-side check
if is_admin:
detailed_model["ip"] = ip
detailed_model["port"] = port
detailed_models.append(detailed_model)
return filtered_models, detailed_models
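# Minimal usage sketch (assumes `dataset` came from load_or_create_dataset()):
#     filtered, detailed = search_models(dataset, name_search="llama", family="llama")
# `filtered` backs the results table shown to every visitor; `detailed` additionally carries
# digest/modified_at/location fields and is used for the per-row detail view (IP/port are
# included in both only when the server-side admin check passes).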
def create_ui() -> gr.Blocks:
"""
Create the Gradio user interface with server-side authorization.
Returns:
gr.Blocks: Gradio interface
"""
# Load dataset
try:
dataset = load_or_create_dataset()
except Exception as e:
# Fallback to empty dataset if loading fails
logger.error(f"Failed to load dataset: {str(e)}")
dataset = Dataset.from_dict({
"ip": [],
"port": [],
"country": [],
"region": [],
"org": [],
"models": []
})
# Server-side authorization check
is_admin = verify_admin_authorization()
# Get unique values for dropdowns
families = [""] + get_unique_values(dataset, "family")
parameter_sizes = [""] + get_unique_values(dataset, "parameter_size")
# Initial search results
initial_results, initial_details = search_models(dataset)
with gr.Blocks(title="Ollama Instance & Model Browser") as app:
gr.Markdown("# Ollama Instance & Model Browser")
with gr.Tabs() as tabs:
with gr.Tab("Browse Models"):
with gr.Row():
with gr.Column(scale=1):
name_search = gr.Textbox(label="Model Name Search")
family_dropdown = gr.Dropdown(
choices=families,
label="Model Family",
value=""
)
parameter_size_dropdown = gr.Dropdown(
choices=parameter_sizes,
label="Parameter Size",
value=""
)
search_button = gr.Button("Search Models")
with gr.Row():
                    model_results = gr.DataFrame(
                        value=pd.DataFrame(initial_results),
                        label="Model Results",
                        interactive=False
                    )
with gr.Row():
model_details = gr.JSON(label="Model Details")
                def search_callback(name, family, parameter_size):
                    # Re-run the search and clear the detail panel until a new row is selected
                    results, _ = search_models(dataset, name, family, parameter_size)
                    return pd.DataFrame(results), None
                def select_model(name, family, parameter_size, evt: gr.SelectData):
                    # Recompute the detailed list with the current filter values (passed as event
                    # inputs) so the selected row index lines up with the displayed results
                    _, details = search_models(dataset, name, family, parameter_size)
                    if evt.index[0] < len(details):
                        return details[evt.index[0]]
                    return None
search_button.click(
search_callback,
inputs=[name_search, family_dropdown, parameter_size_dropdown],
outputs=[model_results, model_details]
)
                model_results.select(
                    select_model,
                    inputs=[name_search, family_dropdown, parameter_size_dropdown],
                    outputs=model_details
                )
# Only show Shodan Scan tab for admins - server-side check
if is_admin:
with gr.Tab("Shodan Scan"):
gr.Markdown("## Scan for Ollama Instances")
gr.Markdown("**Note:** This scan will update the dataset with new Ollama instances.")
scan_button = gr.Button("Start Scan")
scan_output = gr.Textbox(label="Scan Status")
scan_button.click(
lambda progress=gr.Progress(): scan_shodan(progress),
outputs=scan_output
)
# Refresh dataset when the app starts
def refresh_data():
nonlocal dataset
try:
dataset = load_or_create_dataset()
except Exception as e:
logger.error(f"Failed to refresh dataset: {str(e)}")
# Continue with existing dataset
            results, _ = search_models(dataset)
            return pd.DataFrame(results)
app.load(
fn=refresh_data,
outputs=model_results
)
return app
# Main entry point
if __name__ == "__main__":
try:
ui = create_ui()
ui.launch()
except Exception as e:
logger.critical(f"Failed to start application: {str(e)}")