Spaces:
Running
on
Zero
Running
on
Zero
""" | |
visualization.py | |
Routines for plotting VQE results and molecular visualization using Plotly | |
""" | |
import plotly.graph_objects as go | |
import numpy as np | |
import logging | |
import json | |
# Configure logging
# NOTE(review): this raises the *root* logger to DEBUG at import time, which
# affects every logger in the process; kept so existing behaviour is unchanged.
logging.getLogger().setLevel(logging.DEBUG)
# Module logger: everything from DEBUG upward goes to the console and to
# molecule_creation.log.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# logging.getLogger returns the same object on re-import / importlib.reload,
# so only attach handlers once; otherwise every message would be duplicated.
if not logger.handlers:
    # Create handlers with DEBUG level
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    # Explicit encoding so the log file does not depend on the platform locale.
    file_handler = logging.FileHandler('molecule_creation.log', encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    # Shared formatter: timestamp, logger name, level, message.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
# Atom display colours (hex strings) keyed by element symbol, tuned for
# visibility. Lookups in this module use '#808080' for symbols not listed here,
# which is also carbon's colour, so every element in ATOM_RADII gets an entry.
ATOM_COLORS = {
    'H': '#E8E8E8',   # Light gray - more visible than pure white
    'He': '#D9FFFF',  # Pale cyan (CPK/Jmol helium); keeps He distinct from the gray fallback
    'Be': '#42F5EC',  # Bright turquoise
    'Li': '#FF34B3',  # Hot pink
    'B': '#FFB347',   # Pastel orange
    'C': '#808080',   # Classic gray
    'N': '#4169E1',   # Royal blue
    'O': '#FF4500',   # Orange red
    'F': '#DAA520',   # Goldenrod
    'Ne': '#FF1493',  # Deep pink
}
# Atomic radii in ångströms (Å), keyed by element symbol.
# NOTE(review): H and Li-F appear to match Slater's (1964) empirical atomic
# radii; He and Ne (which Slater does not list) look like calculated radii.
# Confirm the source before relying on these values quantitatively.
ATOM_RADII = dict(
    H=0.25,
    He=0.31,
    Li=1.45,
    Be=1.05,
    B=0.85,
    C=0.70,
    N=0.65,
    O=0.60,
    F=0.50,
    Ne=0.38,
)
# Display sizes are the radii themselves (already in Å). This is an alias of
# the same dict object, not a copy.
ATOM_SIZES = ATOM_RADII
# Line style shared by every bond trace.
BOND_STYLE = {
    'color': '#2F4F4F',  # dark slate gray
    'width': 8,          # thick line for visibility
    'dash': 'solid',     # bonds are drawn as solid lines
}
def format_molecule_params(molecule_data: dict) -> str:
    """
    Build an HTML summary card for a molecule's parameters.

    Args:
        molecule_data: Mapping describing the molecule. The keys read are
            'name', 'formula', 'electron_count', 'basis', 'spatial_orbitals'
            and 'charge' (each rendered as 'N/A' when absent), plus 'GPU_time'.
            'GPU_time' is divided by 60 and labelled as minutes, so it is
            presumably seconds; it defaults to 60.

    Returns:
        A ``<div>`` holding an unstyled list of the parameters. If any step
        raises, the exception is logged and a fixed error ``<div>`` is
        returned instead of propagating.
    """
    try:
        lookup = molecule_data.get
        minutes = lookup('GPU_time', 60) / 60
        name = lookup('name', 'N/A')
        formula = lookup('formula', 'N/A')
        electrons = lookup('electron_count', 'N/A')
        basis = lookup('basis', 'N/A')
        orbitals = lookup('spatial_orbitals', 'N/A')
        charge = lookup('charge', 'N/A')
        # One literal piece per output line (implicit concatenation).
        return (
            "\n"
            "<div style='font-family: monospace; padding: 10px;'>\n"
            "<h3>Molecule Parameters:</h3>\n"
            "<ul style='list-style-type: none; padding-left: 0;'>\n"
            f"<li><b>Name:</b> {name} - {formula}</li>\n"
            f"<li><b>Electrons:</b> {electrons}</li>\n"
            f"<li><b>Basis:</b> {basis}, <b>Spatial Orbitals:</b> {orbitals}</li>\n"
            f"<li><b>Charge:</b> {charge}</li>\n"
            f"<li><b>Simulation time: up to {minutes:.1f} minutes</b></li>\n"
            "</ul>\n"
            "</div>\n"
        )
    except Exception as e:
        logger.error(f"Error formatting molecule parameters: {e}")
        return "<div>Error: Could not format molecule parameters</div>"
def _sphere_mesh(radius, center, resolution=20):
    """Return (x, y, z) grids sampling a sphere of ``radius`` around ``center``."""
    # Longitude 0..2*pi and latitude -pi/2..pi/2, ``resolution`` samples each;
    # meshgrid yields (resolution, resolution) arrays as go.Surface expects.
    lon, lat = np.meshgrid(
        np.linspace(0, 2 * np.pi, resolution),
        np.linspace(-np.pi / 2, np.pi / 2, resolution),
    )
    x = center[0] + radius * np.cos(lat) * np.cos(lon)
    y = center[1] + radius * np.cos(lat) * np.sin(lon)
    z = center[2] + radius * np.sin(lat)
    return x, y, z


def _load_geometry(molecule_id, molecules_path):
    """Return ``molecule_id``'s ``geometry_template`` from the JSON file, or None (after logging) if absent."""
    with open(molecules_path, 'r', encoding='utf-8') as f:
        molecules = json.load(f)
    if molecule_id not in molecules:
        logger.error(f"Unknown molecule {molecule_id}")
        return None
    molecule_data = molecules[molecule_id]
    if 'geometry_template' not in molecule_data:
        logger.error(f"No geometry template found for {molecule_id}")
        return None
    return molecule_data['geometry_template']


def _add_atom(fig, symbol, center):
    """Add one atom to ``fig``: a uniformly coloured sphere surface plus a centred text label."""
    radius = ATOM_RADII.get(symbol, 0.5)
    color = ATOM_COLORS.get(symbol, '#808080')
    x, y, z = _sphere_mesh(radius, center)
    fig.add_trace(go.Surface(
        x=x, y=y, z=z,
        # Both colorscale stops use the same colour, so the whole sphere is one colour.
        colorscale=[[0, color], [1, color]],
        showscale=False,
        opacity=0.85,
        hoverinfo='text',
        hovertext=f"{symbol} at ({center[0]:.2f}, {center[1]:.2f}, {center[2]:.2f})",
        name=symbol
    ))
    fig.add_trace(go.Scatter3d(
        x=[center[0]],
        y=[center[1]],
        z=[center[2]],
        mode='text',
        text=[symbol],
        textposition='middle center',
        textfont=dict(
            size=18,
            color='#1A1A1A',  # Dark charcoal
            family='Arial Black'
        ),
        showlegend=False
    ))


def _bonded_pairs(symbols, positions, tolerance=1.3):
    """Yield index pairs (i, j), i < j, whose distance is at most ``tolerance`` x the sum of their ATOM_RADII."""
    from itertools import combinations
    for i, j in combinations(range(len(symbols)), 2):
        cutoff = (ATOM_RADII.get(symbols[i], 0.5) + ATOM_RADII.get(symbols[j], 0.5)) * tolerance
        if np.linalg.norm(positions[i] - positions[j]) <= cutoff:
            yield i, j


def create_molecule_viewer(molecule_id: str, scale_factor: float,
                           molecules_path: str = 'molecules.json') -> go.Figure:
    """
    Create a 3D Plotly visualization of a molecule read from a JSON database.

    Each atom is drawn as a ``go.Surface`` sphere whose radius comes from
    ATOM_RADII (Å; 0.5 for unknown symbols) and whose colour comes from
    ATOM_COLORS ('#808080' for unknown symbols), with a text label at its
    centre. Atom positions are the molecule's ``geometry_template``
    coordinates multiplied by ``scale_factor``; sphere radii are not scaled.
    A bond line joins two atoms when their distance in the *unscaled* template
    is at most 1.3 x the sum of their radii. The scaled geometry therefore
    shows the same bonds for any ``scale_factor``, including zero or negative
    (mirrored) values.

    Args:
        molecule_id: Key of the molecule in the JSON file (e.g. "H2").
        scale_factor: Multiplier applied to every coordinate (1.0 = original size).
        molecules_path: Path of the molecule JSON file; defaults to
            'molecules.json', resolved against the current working directory.

    Returns:
        The Plotly figure. On any failure (file unreadable, unknown molecule,
        missing template, malformed geometry) the problem is logged and an
        empty ``go.Figure()`` is returned instead of raising.
    """
    logger.info(f"Creating enhanced 3D Plotly view for {molecule_id} with scale_factor={scale_factor}")
    try:
        geometry = _load_geometry(molecule_id, molecules_path)
        if geometry is None:
            return go.Figure()

        # Template entries are [symbol, [x, y, z]].
        symbols = [atom[0] for atom in geometry]
        template = np.array([atom[1] for atom in geometry], dtype=float)
        positions = template * scale_factor

        fig = go.Figure()
        for symbol, center in zip(symbols, positions):
            _add_atom(fig, symbol, center)

        # Bonds are decided on the unscaled template so that the cutoff does
        # not flip sign or collapse to zero with scale_factor.
        for i, j in _bonded_pairs(symbols, template):
            fig.add_trace(go.Scatter3d(
                x=[positions[i, 0], positions[j, 0]],
                y=[positions[i, 1], positions[j, 1]],
                z=[positions[i, 2], positions[j, 2]],
                mode='lines',
                line=BOND_STYLE,
                hoverinfo='none',
                showlegend=False
            ))

        # Hide axes/grids and keep the data aspect ratio so spheres stay round.
        fig.update_layout(
            scene=dict(
                aspectmode='data',
                xaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                yaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                zaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                camera=dict(
                    up=dict(x=0, y=1, z=0),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5)
                )
            ),
            margin=dict(l=0, r=0, t=0, b=0),
            showlegend=False,
            width=600,
            height=600,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)'
        )
        return fig
    except Exception as e:
        logger.error(f"Failed to create molecule viewer: {e}", exc_info=True)
        return go.Figure()
def plot_convergence(results):
    """
    Create a convergence plot from VQE iteration history using Plotly.

    Args:
        results: Either a dict holding a ``history`` sequence (and optionally
            ``final_energy``), or the history sequence itself. Each history
            entry is a dict with an ``energy`` key (plus an optional
            ``iteration``, defaulting to the entry's index) or a bare number.
            Entries whose energy is missing or not convertible to float are
            skipped with a warning.

    Returns:
        A Plotly figure with energy per iteration and a dashed red line at the
        final energy. The final energy is ``results['final_energy']`` when it
        converts to float, otherwise the last plotted energy.

    Raises:
        ValueError: If no usable history entry is found (including a missing
            or ``None`` history).
    """
    logger.info(f"Plotting convergence with results type: {type(results)}")

    raw_history = results.get('history', []) if isinstance(results, dict) else results
    # None counts as empty; list() lets one-shot iterables be measured and iterated.
    history = [] if raw_history is None else list(raw_history)
    logger.info(f"History type: {type(raw_history)}, length: {len(history)}")

    iterations = []
    energies = []
    for i, entry in enumerate(history):
        # Read both values before appending anything, so a bad entry can never
        # leave `iterations` and `energies` with different lengths.
        try:
            if isinstance(entry, dict):
                iteration = entry.get('iteration', i)
                energy = float(entry['energy'])
            else:
                iteration = i
                energy = float(entry)
        except (KeyError, ValueError, TypeError) as e:
            logger.warning(f"Skipping invalid entry {i}: {str(e)}")
            continue
        iterations.append(iteration)
        energies.append(energy)

    if not energies:
        raise ValueError("No valid iteration data found in results")

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=iterations,
        y=energies,
        mode='lines+markers',
        name='Energy',
        line=dict(color='blue', width=2),
        marker=dict(size=8)
    ))

    # Prefer an explicit final energy; fall back to the last plotted value.
    final_energy = energies[-1]
    if isinstance(results, dict) and 'final_energy' in results:
        raw_final = results['final_energy']
        try:
            final_energy = float(raw_final)
        except (ValueError, TypeError):
            logger.warning(f"Ignoring non-numeric final_energy {raw_final!r}; using last history energy")

    fig.add_hline(
        y=final_energy,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Final Energy: {final_energy:.6f}",
        annotation_position="bottom right"
    )

    fig.update_layout(
        title='VQE Convergence',
        xaxis_title='Iteration',
        yaxis_title='Energy (Hartree)',
        showlegend=True,
        hovermode='x',
        width=800,
        height=600,
        template='plotly_white'
    )
    return fig