# A19grey's picture
# Final commit working for release version
# 3045ca3
"""
visualization.py
Routines for plotting VQE results and molecular visualization using Plotly
"""
import plotly.graph_objects as go
import numpy as np
import logging
import json
# Configure logging
# NOTE: importing this module sets the *root* logger to DEBUG, which changes the
# effective level of every logger in the process that does not set its own.
logging.getLogger().setLevel(logging.DEBUG)
# Module logger: DEBUG and above go to both the console and a log file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Names that identify the handlers installed below. A repeated import
# (e.g. importlib.reload) removes and closes the previous pair instead of
# stacking duplicates. Duplicates would emit every message twice and leak
# open log-file handles.
_CONSOLE_HANDLER_NAME = f"{__name__}.console"
_FILE_HANDLER_NAME = f"{__name__}.file"
for _stale in [h for h in logger.handlers
               if h.name in (_CONSOLE_HANDLER_NAME, _FILE_HANDLER_NAME)]:
    logger.removeHandler(_stale)
    _stale.close()
# Create handlers with DEBUG level
console_handler = logging.StreamHandler()  # writes to sys.stderr by default
console_handler.set_name(_CONSOLE_HANDLER_NAME)
console_handler.setLevel(logging.DEBUG)
# Appends to molecule_creation.log in the current working directory.
file_handler = logging.FileHandler('molecule_creation.log', encoding='utf-8')
file_handler.set_name(_FILE_HANDLER_NAME)
file_handler.setLevel(logging.DEBUG)
# Create and set formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
# Add handlers to logger
logger.addHandler(console_handler)
logger.addHandler(file_handler)
# Atom display colours (hex strings) keyed by element symbol. The palette is
# tuned for on-screen visibility rather than strict CPK convention. Symbols
# missing from this table fall back to a default colour at lookup time.
# Every element listed in ATOM_RADII has an entry here, so none of them is
# silently drawn in the fallback grey.
ATOM_COLORS = {
    'H': '#E8E8E8',   # Light gray - more visible than pure white
    'He': '#D9FFFF',  # Pale cyan (Jmol/CPK helium colour)
    'Be': '#42F5EC',  # Bright turquoise
    'Li': '#FF34B3',  # Hot pink
    'B': '#FFB347',   # Pastel orange
    'C': '#808080',   # Classic gray
    'N': '#4169E1',   # Royal blue
    'O': '#FF4500',   # Orange red
    'F': '#DAA520',   # Goldenrod
    'Ne': '#FF1493',  # Deep pink
}
# Atomic radii in Ångströms (Å), keyed by element symbol. They are used both as
# rendered sphere radii and as inputs to bond detection.
# NOTE(review): H through F appear to match Slater's empirical radii, while
# He and Ne look like calculated radii. Confirm the intended source before
# extending the table.
ATOM_RADII = dict(
    H=0.25,
    He=0.31,
    Li=1.45,
    Be=1.05,
    B=0.85,
    C=0.70,
    N=0.65,
    O=0.60,
    F=0.50,
    Ne=0.38,
)
# Sphere sizes are the radii themselves (already in Å). This binds the same
# dict object, not a copy, so a mutation through either name shows in both.
ATOM_SIZES = ATOM_RADII
# Line style shared by every bond segment; passed as the Scatter3d `line` argument.
BOND_STYLE = {
    'color': '#2F4F4F',  # dark slate gray
    'width': 8,          # thick enough to stay visible at default zoom
    'dash': 'solid',     # bonds are drawn as continuous lines
}
def format_molecule_params(molecule_data: dict) -> str:
    """
    Render the main fields of a molecule record as a small HTML snippet.

    Fields that are absent are shown as ``N/A``. ``GPU_time`` defaults to 60
    when absent and is divided by 60 (presumably seconds to minutes). It is
    shown with one decimal place.

    Args:
        molecule_data: Mapping that may contain ``name``, ``formula``,
            ``electron_count``, ``basis``, ``spatial_orbitals``, ``charge``
            and ``GPU_time``.
    Returns:
        The HTML string, or a short error ``<div>`` if formatting raised.
    """
    try:
        field = molecule_data.get
        minutes = field('GPU_time', 60) / 60
        # Adjacent literals are joined at compile time into one string;
        # only the f-string pieces are evaluated at call time.
        return (
            "\n"
            "<div style='font-family: monospace; padding: 10px;'>\n"
            "<h3>Molecule Parameters:</h3>\n"
            "<ul style='list-style-type: none; padding-left: 0;'>\n"
            f"<li><b>Name:</b> {field('name', 'N/A')} - {field('formula', 'N/A')}</li>\n"
            f"<li><b>Electrons:</b> {field('electron_count', 'N/A')}</li>\n"
            f"<li><b>Basis:</b> {field('basis', 'N/A')}, <b>Spatial Orbitals:</b> {field('spatial_orbitals', 'N/A')}</li>\n"
            f"<li><b>Charge:</b> {field('charge', 'N/A')}</li>\n"
            f"<li><b>Simulation time: up to {minutes:.1f} minutes</b></li>\n"
            "</ul>\n"
            "</div>\n"
        )
    except Exception as e:
        logger.error(f"Error formatting molecule parameters: {e}")
        return "<div>Error: Could not format molecule parameters</div>"
def _sphere_mesh(radius, center, resolution=20):
    """
    Sample a sphere surface on a longitude/latitude grid.

    Args:
        radius: Sphere radius, in the same units as ``center``.
        center: Indexable ``(x, y, z)`` coordinates of the sphere centre.
        resolution: Number of samples along each angle.
    Returns:
        Tuple ``(x, y, z)`` of ``resolution x resolution`` NumPy arrays, used as
        the 2-D coordinate grids of a ``go.Surface`` trace.
    """
    # Longitude spans the full circle; latitude runs from pole to pole.
    phi, theta = np.meshgrid(
        np.linspace(0, 2 * np.pi, resolution),
        np.linspace(-np.pi / 2, np.pi / 2, resolution),
    )
    x = center[0] + radius * np.cos(theta) * np.cos(phi)
    y = center[1] + radius * np.cos(theta) * np.sin(phi)
    z = center[2] + radius * np.sin(theta)
    return x, y, z


def _bond_pairs(symbols, positions, radii, scale_factor, tolerance=1.3, default_radius=0.5):
    """
    Return index pairs ``(i, j)`` with ``i < j`` for atoms drawn as bonded.

    Two atoms are bonded when their distance is at most
    ``(r_i + r_j) * tolerance * scale_factor``. ``positions`` are expected to be
    already multiplied by ``scale_factor``, so the factor cancels. The bond
    topology is therefore that of the unscaled template geometry.

    Args:
        symbols: Element symbol of each atom.
        positions: ``(n_atoms, 3)`` array of (scaled) coordinates.
        radii: Mapping from symbol to radius; unknown symbols use ``default_radius``.
        scale_factor: Factor that was applied to ``positions``.
        tolerance: Multiplier on the summed radii.
        default_radius: Radius for symbols missing from ``radii``.
    Returns:
        List of bonded ``(i, j)`` index pairs.
    """
    pairs = []
    n_atoms = len(symbols)
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            dist = np.linalg.norm(positions[i] - positions[j])
            cutoff = (radii.get(symbols[i], default_radius) + radii.get(symbols[j], default_radius)) * tolerance
            if dist <= cutoff * scale_factor:
                pairs.append((i, j))
    return pairs


def create_molecule_viewer(molecule_id: str, scale_factor: float, bond_tolerance: float = 1.3) -> go.Figure:
    """
    Create an enhanced 3D visualization of the molecule using Plotly.

    Each atom is drawn as a translucent ``go.Surface`` sphere with a text label.
    The sphere radius comes from ``ATOM_RADII`` in Ångströms and is not scaled
    by ``scale_factor``. Bonds are ``go.Scatter3d`` line segments styled with
    ``BOND_STYLE``.

    NOTE(review): with ATOM_RADII['H'] == 0.25 and the default tolerance, two
    hydrogens are bonded only when they are within 0.65 Å in the template
    geometry. That is shorter than the ~0.74 Å H2 bond length, so confirm the
    templates or pass a larger ``bond_tolerance``.

    Args:
        molecule_id: Key of the molecule in ``molecules.json`` (e.g. "H2").
        scale_factor: Factor applied to every atomic coordinate (1.0 = original size).
        bond_tolerance: Multiplier on the sum of two atomic radii that sets the
            bonding cutoff. The default 1.3 matches the previous hard-coded value.
    Returns:
        Plotly Figure with the 3D molecule. An empty figure is returned, and the
        problem is logged, when the molecule is unknown, has no
        ``geometry_template``, or any step fails.
    """
    logger.info(f"Creating enhanced 3D Plotly view for {molecule_id} with scale_factor={scale_factor}")
    try:
        # JSON is UTF-8 by specification; don't rely on the platform default encoding.
        with open('molecules.json', 'r', encoding='utf-8') as f:
            molecules = json.load(f)
        if molecule_id not in molecules:
            logger.error(f"Unknown molecule {molecule_id}")
            return go.Figure()
        molecule_data = molecules[molecule_id]
        if 'geometry_template' not in molecule_data:
            logger.error(f"No geometry template found for {molecule_id}")
            return go.Figure()
        # Each template entry is [symbol, coordinates]; only the coordinates are scaled.
        scaled_geometry = [
            [atom[0], [coord * scale_factor for coord in atom[1]]]
            for atom in molecule_data['geometry_template']
        ]

        fig = go.Figure()

        # Atoms: a sphere plus a centred text label.
        for symbol, pos in scaled_geometry:
            radius = ATOM_RADII.get(symbol, 0.5)
            color = ATOM_COLORS.get(symbol, '#808080')
            x, y, z = _sphere_mesh(radius, pos)
            fig.add_trace(go.Surface(
                x=x, y=y, z=z,
                colorscale=[[0, color], [1, color]],  # flat colour across the surface
                showscale=False,
                opacity=0.85,
                hoverinfo='text',
                hovertext=f"{symbol} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})",
                name=symbol
            ))
            fig.add_trace(go.Scatter3d(
                x=[pos[0]],
                y=[pos[1]],
                z=[pos[2]],
                mode='text',
                text=[symbol],
                textposition='middle center',
                textfont=dict(
                    size=18,
                    color='#1A1A1A',  # dark charcoal
                    family='Arial Black'
                ),
                showlegend=False
            ))

        # Bonds between every pair within the radius-based cutoff.
        positions = np.array([pos for _, pos in scaled_geometry])
        symbols = [symbol for symbol, _ in scaled_geometry]
        for i, j in _bond_pairs(symbols, positions, ATOM_RADII, scale_factor, tolerance=bond_tolerance):
            fig.add_trace(go.Scatter3d(
                x=[positions[i, 0], positions[j, 0]],
                y=[positions[i, 1], positions[j, 1]],
                z=[positions[i, 2], positions[j, 2]],
                mode='lines',
                line=BOND_STYLE,
                hoverinfo='none',
                showlegend=False
            ))

        # Bare scene: hidden axes, equal aspect ratio, transparent background.
        fig.update_layout(
            scene=dict(
                aspectmode='data',
                xaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                yaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                zaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                camera=dict(
                    up=dict(x=0, y=1, z=0),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5)
                )
            ),
            margin=dict(l=0, r=0, t=0, b=0),
            showlegend=False,
            width=600,
            height=600,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)'
        )
        return fig
    except Exception as e:
        logger.error(f"Failed to create molecule viewer: {e}", exc_info=True)
        return go.Figure()
def plot_convergence(results):
"""
Create a convergence plot from VQE iteration history using Plotly.
Returns Plotly figure
Args:
results: Dictionary containing VQE results including history of energy values,
or list of energy values directly
"""
logger.info(f"Plotting convergence with results type: {type(results)}")
# Extract iteration numbers and energies
iterations = []
energies = []
# Handle different input formats
if isinstance(results, dict):
history = results.get('history', [])
else:
history = results
logger.info(f"History type: {type(history)}, length: {len(history)}")
# Handle both dictionary entries and direct energy values
for i, entry in enumerate(history):
if isinstance(entry, dict):
try:
iterations.append(entry.get('iteration', i))
energies.append(entry['energy'])
except Exception as e:
logger.warning(f"Skipping invalid entry {i}: {str(e)}")
continue
else:
try:
iterations.append(i)
energies.append(float(entry))
except (ValueError, TypeError) as e:
logger.warning(f"Skipping invalid entry {i}: {str(e)}")
continue
if not iterations or not energies:
raise ValueError("No valid iteration data found in results")
# Create Plotly figure
fig = go.Figure()
# Add energy convergence line
fig.add_trace(go.Scatter(
x=iterations,
y=energies,
mode='lines+markers',
name='Energy',
line=dict(color='blue', width=2),
marker=dict(size=8)
))
# Add final energy line if available
final_energy = None
if isinstance(results, dict) and 'final_energy' in results:
try:
final_energy = float(results['final_energy'])
except (ValueError, TypeError):
pass
if final_energy is None and energies:
final_energy = energies[-1]
if final_energy is not None:
fig.add_hline(
y=final_energy,
line_dash="dash",
line_color="red",
annotation_text=f"Final Energy: {final_energy:.6f}",
annotation_position="bottom right"
)
# Update layout
fig.update_layout(
title='VQE Convergence',
xaxis_title='Iteration',
yaxis_title='Energy (Hartree)',
showlegend=True,
hovermode='x',
width=800,
height=600,
template='plotly_white'
)
return fig