|
""" |
|
visualization.py |
|
Routines for plotting VQE results and molecular visualization using Plotly |
|
""" |
|
|
|
import itertools
import json
import logging

import numpy as np
import plotly.graph_objects as go
|
|
|
|
|
|
|
# NOTE(review): this raises the *root* logger to DEBUG as a side effect of
# importing this module, which changes the effective level of every logger that
# does not set its own level. Consider moving it to the application entry point.
logging.getLogger().setLevel(logging.DEBUG)

# Module logger: DEBUG and above are written both to the console (stderr) and to
# 'molecule_creation.log'. Records also propagate to the root logger, so any
# handlers configured there will see them too.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# logging.getLogger(__name__) returns the same logger object on every
# import/reload of this module. Attach the handlers only once; otherwise each
# reload adds another console/file pair (every message is emitted repeatedly
# and another log file handle is left open).
if not logger.handlers:
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)

    file_handler = logging.FileHandler('molecule_creation.log')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
|
|
|
|
|
# Display colour (hex string) for each element symbol. The 3D viewer falls back
# to grey '#808080' for symbols missing here, so every element listed in
# ATOM_RADII should also have a colour (He previously rendered as carbon-grey).
ATOM_COLORS = {
    'H': '#E8E8E8',
    'He': '#D9FFFF',
    'Be': '#42F5EC',
    'Li': '#FF34B3',
    'B': '#FFB347',
    'C': '#808080',
    'N': '#4169E1',
    'O': '#FF4500',
    'F': '#DAA520',
    'Ne': '#FF1493',
}

# Per-element radius used both as the drawn sphere radius and in the bond
# distance test of the 3D viewer.
# NOTE(review): values are presumably in Ångströms (the viewer documents Å
# units) and appear to mix empirical and calculated radius tables — confirm.
ATOM_RADII = {
    "H": 0.25,
    "He": 0.31,
    "Li": 1.45,
    "Be": 1.05,
    "B": 0.85,
    "C": 0.70,
    "N": 0.65,
    "O": 0.60,
    "F": 0.50,
    "Ne": 0.38,
}

# Alias kept for existing callers; it is the same dict object as ATOM_RADII.
ATOM_SIZES = ATOM_RADII

# Plotly line style applied to every bond segment (Scatter3d ``line=``).
BOND_STYLE = dict(
    color='#2F4F4F',
    width=8,
    dash='solid',
)
|
|
|
def format_molecule_params(molecule_data: dict) -> str:
    """
    Format molecule parameters into a readable HTML string.

    Fields read from ``molecule_data`` (each shown as 'N/A' when absent):
    'name', 'formula', 'electron_count', 'basis', 'spatial_orbitals', 'charge'.
    The simulation time comes from 'GPU_time', divided by 60 and shown in
    minutes with one decimal. It defaults to 60 (i.e. "1.0 minutes") when the
    key is absent. If the value cannot be converted with ``float()``, only that
    line shows 'N/A'; the rest of the panel is still rendered.

    Args:
        molecule_data: Dictionary containing molecule parameters.

    Returns:
        HTML formatted string of parameters, or a short error ``<div>`` if
        formatting fails (e.g. ``molecule_data`` is not a mapping).
    """
    try:
        gpu_time = molecule_data.get('GPU_time', 60)
        try:
            sim_time = f"up to {float(gpu_time) / 60:.1f} minutes"
        except (TypeError, ValueError):
            # A bad time value should not blank out the other parameters.
            sim_time = "N/A"

        params_html = f"""
        <div style='font-family: monospace; padding: 10px;'>
            <h3>Molecule Parameters:</h3>
            <ul style='list-style-type: none; padding-left: 0;'>
                <li><b>Name:</b> {molecule_data.get('name', 'N/A')} - {molecule_data.get('formula', 'N/A')}</li>
                <li><b>Electrons:</b> {molecule_data.get('electron_count', 'N/A')}</li>
                <li><b>Basis:</b> {molecule_data.get('basis', 'N/A')}, <b>Spatial Orbitals:</b> {molecule_data.get('spatial_orbitals', 'N/A')}</li>
                <li><b>Charge:</b> {molecule_data.get('charge', 'N/A')}</li>
                <li><b>Simulation time: {sim_time}</b></li>
            </ul>
        </div>
        """
        return params_html
    except Exception as e:
        # Best-effort UI helper: log and show a placeholder instead of raising.
        logger.error(f"Error formatting molecule parameters: {e}", exc_info=True)
        return "<div>Error: Could not format molecule parameters</div>"
|
|
|
def _sphere_coordinates(radius, center, resolution=20):
    """
    Build the (x, y, z) grids of a sphere, used as the data of a ``go.Surface`` trace.

    Args:
        radius: Sphere radius, in the same units as ``center``.
        center: Sequence whose first three items are the sphere centre (x, y, z).
        resolution: Number of samples along both longitude and latitude.

    Returns:
        Tuple of three numpy arrays of shape ``(resolution, resolution)``.
    """
    # Rows run over latitude theta in [-pi/2, pi/2], columns over longitude
    # phi in [0, 2*pi]; both endpoints of phi are included so the seam closes.
    phi, theta = np.meshgrid(
        np.linspace(0, 2 * np.pi, resolution),
        np.linspace(-np.pi / 2, np.pi / 2, resolution),
    )
    x = center[0] + radius * np.cos(theta) * np.cos(phi)
    y = center[1] + radius * np.cos(theta) * np.sin(phi)
    z = center[2] + radius * np.sin(theta)
    return x, y, z


def _bonded_pairs(symbols, positions, scale_factor):
    """
    Yield index pairs ``(i, j)`` with ``i < j`` whose atoms should be joined by a bond.

    A pair is bonded when its distance is at most 1.3 times the sum of the two
    atomic radii (0.5 for symbols missing from ATOM_RADII), multiplied by
    ``scale_factor``. Pairs are yielded in row-major order: (0, 1), (0, 2), ...

    Args:
        symbols: Element symbol of each atom.
        positions: numpy array of atom positions, one row per atom.
        scale_factor: Factor the positions were scaled by.
    """
    # NOTE(review): positions are already multiplied by scale_factor, so scaling
    # the threshold as well makes the bond set identical for every positive
    # scale_factor (stretched geometries never "break" a bond), while the
    # drawn sphere radii stay unscaled. Confirm this is the intended behavior.
    for i, j in itertools.combinations(range(len(symbols)), 2):
        dist = np.linalg.norm(positions[i] - positions[j])
        radii_sum = (ATOM_RADII.get(symbols[i], 0.5) + ATOM_RADII.get(symbols[j], 0.5)) * 1.3
        if dist <= radii_sum * scale_factor:
            yield i, j


def create_molecule_viewer(molecule_id: str, scale_factor: float,
                           molecules_path: str = 'molecules.json') -> go.Figure:
    """
    Create an enhanced 3D visualization of the molecule using Plotly.

    Each atom is drawn as a true sphere (a ``go.Surface`` trace) whose radius
    comes from ATOM_RADII, together with a text label at its centre. The scene
    uses ``aspectmode='data'``, so spheres keep their size in data units
    (presumably Ångströms). Bonds are line segments between atom pairs that
    pass a radius-based distance test.

    Args:
        molecule_id: Molecule identifier (e.g., "H2"); a top-level key of the
            molecules JSON file.
        scale_factor: Factor to scale the molecule geometry by (1.0 = original
            size). Only atom positions are scaled, not sphere radii.
        molecules_path: Path of the JSON file with molecule definitions. The
            entry for ``molecule_id`` must contain 'geometry_template', a list
            of ``[symbol, [x, y, z]]`` items.

    Returns:
        Plotly Figure object with 3D molecule visualization. An empty Figure is
        returned, and the problem is logged, when the molecule is unknown, has
        no geometry template, or any other error occurs.
    """
    logger.info(f"Creating enhanced 3D Plotly view for {molecule_id} with scale_factor={scale_factor}")
    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(molecules_path, 'r', encoding='utf-8') as f:
            molecules = json.load(f)

        if molecule_id not in molecules:
            logger.error(f"Unknown molecule {molecule_id}")
            return go.Figure()

        molecule_data = molecules[molecule_id]
        if 'geometry_template' not in molecule_data:
            logger.error(f"No geometry template found for {molecule_id}")
            return go.Figure()

        # Scale coordinates only; element symbols are carried over unchanged.
        scaled_geometry = [
            [atom[0], [coord * scale_factor for coord in atom[1]]]
            for atom in molecule_data['geometry_template']
        ]

        fig = go.Figure()

        # Atoms: one sphere surface plus one centred text label per atom.
        for symbol, pos in scaled_geometry:
            radius = ATOM_RADII.get(symbol, 0.5)
            color = ATOM_COLORS.get(symbol, '#808080')
            x, y, z = _sphere_coordinates(radius, pos)

            fig.add_trace(go.Surface(
                x=x, y=y, z=z,
                # Two-stop scale with the same colour gives a uniform fill.
                colorscale=[[0, color], [1, color]],
                showscale=False,
                opacity=0.85,
                hoverinfo='text',
                hovertext=f"{symbol} at ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})",
                name=symbol,
            ))

            fig.add_trace(go.Scatter3d(
                x=[pos[0]],
                y=[pos[1]],
                z=[pos[2]],
                mode='text',
                text=[symbol],
                textposition='middle center',
                textfont=dict(
                    size=18,
                    color='#1A1A1A',
                    family='Arial Black',
                ),
                showlegend=False,
            ))

        # Bonds: one line segment per bonded pair.
        symbols = [atom[0] for atom in scaled_geometry]
        positions = np.array([atom[1] for atom in scaled_geometry])
        for i, j in _bonded_pairs(symbols, positions, scale_factor):
            fig.add_trace(go.Scatter3d(
                x=[positions[i, 0], positions[j, 0]],
                y=[positions[i, 1], positions[j, 1]],
                z=[positions[i, 2], positions[j, 2]],
                mode='lines',
                line=BOND_STYLE,
                hoverinfo='none',
                showlegend=False,
            ))

        fig.update_layout(
            scene=dict(
                # 'data' keeps equal units on all axes so spheres stay round.
                aspectmode='data',
                xaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                yaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                zaxis=dict(showspikes=False, showbackground=False, showticklabels=False, title=''),
                camera=dict(
                    up=dict(x=0, y=1, z=0),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5),
                ),
            ),
            margin=dict(l=0, r=0, t=0, b=0),
            showlegend=False,
            width=600,
            height=600,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
        )

        return fig

    except Exception as e:
        # UI boundary: log with traceback and return an empty figure instead of raising.
        logger.error(f"Failed to create molecule viewer: {e}", exc_info=True)
        return go.Figure()
|
|
|
def _extract_convergence_series(history):
    """
    Turn a VQE history into aligned lists of iteration numbers and energies.

    Each entry is either a dict with an 'energy' key (and an optional
    'iteration', defaulting to the entry's position) or a bare value that
    ``float()`` accepts. Energies are always converted with ``float()``.

    Args:
        history: Iterable of history entries.

    Returns:
        ``(iterations, energies, skipped)``: two lists of equal length, plus a
        list of ``(index, exception)`` pairs for entries that could not be parsed.
    """
    iterations, energies, skipped = [], [], []
    for index, entry in enumerate(history):
        # Parse the whole entry before appending anything, so a bad entry can
        # never leave iterations and energies with different lengths.
        try:
            if isinstance(entry, dict):
                iteration = entry.get('iteration', index)
                energy = float(entry['energy'])
            else:
                iteration = index
                energy = float(entry)
        except (KeyError, TypeError, ValueError, OverflowError) as exc:
            skipped.append((index, exc))
            continue
        iterations.append(iteration)
        energies.append(energy)
    return iterations, energies, skipped


def plot_convergence(results):
    """
    Create a convergence plot from VQE iteration history using Plotly.

    Args:
        results: Dictionary containing VQE results, with the history of energy
            values under 'history' and optionally a 'final_energy'; or the
            history itself (list/array/iterable of energy values or of dicts
            with an 'energy' key). Unparseable entries are logged and skipped.

    Returns:
        Plotly figure with the energy per iteration and a dashed horizontal line
        at the final energy (``results['final_energy']`` when present and
        numeric, otherwise the last plotted energy).

    Raises:
        ValueError: If no history entry yields a usable energy (including an
            empty or None history).
    """
    logger.info(f"Plotting convergence with results type: {type(results)}")

    history = results.get('history', []) if isinstance(results, dict) else results
    if history is None:
        history = []
    elif not hasattr(history, '__len__'):
        # Generators/iterators have no len(); materialize them once.
        history = list(history)

    logger.info(f"History type: {type(history)}, length: {len(history)}")

    iterations, energies, skipped = _extract_convergence_series(history)
    for index, exc in skipped:
        logger.warning(f"Skipping invalid entry {index}: {str(exc)}")

    if not energies:
        raise ValueError("No valid iteration data found in results")

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=iterations,
        y=energies,
        mode='lines+markers',
        name='Energy',
        line=dict(color='blue', width=2),
        marker=dict(size=8),
    ))

    # Prefer an explicit numeric final_energy; otherwise use the last point.
    final_energy = energies[-1]
    if isinstance(results, dict) and 'final_energy' in results:
        try:
            final_energy = float(results['final_energy'])
        except (ValueError, TypeError):
            pass

    fig.add_hline(
        y=final_energy,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Final Energy: {final_energy:.6f}",
        annotation_position="bottom right",
    )

    fig.update_layout(
        title='VQE Convergence',
        xaxis_title='Iteration',
        yaxis_title='Energy (Hartree)',
        showlegend=True,
        hovermode='x',
        width=800,
        height=600,
        template='plotly_white',
    )

    return fig