Save and Load State#
Problem#
Long-running quantum simulations (VQE, TDVP, QAOA) may run for hours or days. Checkpointing is essential to:
Prevent data loss from crashes, preemption, or hardware failures
Resume interrupted optimizations without restarting from scratch
Share results with collaborators or across compute sessions
Debug incrementally by analyzing intermediate states
Manage HPC resources by splitting work across multiple jobs
This guide covers strategies for checkpointing MPS states, variational parameters, optimization histories, and distributed simulations.
See also
See How to Handle Large Quantum Systems for distributed checkpointing, Debug Simulations for loading checkpoints during debugging, and Design Decisions for serialization format details.
Prerequisites#
You need:
ATLAS-Q with MPS, VQE, TDVP, or QAOA configured
Storage space for checkpoints (MPS with n=50, χ=256 ≈ 0.5-2 GB per checkpoint; see the sizing probe after this list)
Understanding of your optimization problem structure
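How much space a checkpoint actually takes depends on the bond dimensions your run reaches, so it is worth measuring one checkpoint before committing to a long job. The probe below is a minimal sketch; checkpoint_size_mb is a hypothetical helper, not part of the ATLAS-Q API:
import os
import torch

def checkpoint_size_mb(mps, probe_path='size_probe.pt'):
    """Write one throwaway checkpoint and report its on-disk size in MB."""
    torch.save({'tensors': [t.cpu() for t in mps.tensors],
                'bond_dims': mps.bond_dims}, probe_path)
    size_mb = os.path.getsize(probe_path) / 1024**2
    os.remove(probe_path)
    return size_mb

# Multiply by the number of checkpoints you plan to keep to budget disk space.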
Strategies#
Strategy 1: Basic MPS Checkpointing#
Save and load MPS state for later analysis or resumption.
Save complete MPS state:
import time

import torch
from atlas_q.adaptive_mps import AdaptiveMPS

# Run simulation
mps = AdaptiveMPS(num_qubits=50, bond_dim=64, device='cuda')
# ... apply gates, run algorithm ...

# Save complete state
checkpoint = {
    'tensors': [t.cpu() for t in mps.tensors],
    'num_qubits': mps.num_qubits,
    'bond_dims': mps.bond_dims,
    'statistics': {
        'truncation_error': mps.statistics.total_truncation_error,
        'max_bond_dim': mps.statistics.max_bond_dim,
        'num_operations': mps.statistics.num_operations
    },
    'metadata': {
        'timestamp': time.time(),
        'atlas_q_version': '0.6.1'
    }
}
torch.save(checkpoint, 'mps_checkpoint.pt')
print(f"Saved MPS state: {sum(t.numel() for t in mps.tensors)} elements")
Load MPS state:
# Load from disk (to CPU first, then move to the target device)
checkpoint = torch.load('mps_checkpoint.pt', map_location='cpu')

# Reconstruct MPS
mps = AdaptiveMPS(
    num_qubits=checkpoint['num_qubits'],
    bond_dim=max(checkpoint['bond_dims']),  # Use max bond dim
    device='cuda'
)

# Restore tensors and bond dimensions
mps.tensors = [t.to('cuda') for t in checkpoint['tensors']]
mps.bond_dims = checkpoint['bond_dims']

# Restore statistics
mps.statistics.total_truncation_error = checkpoint['statistics']['truncation_error']
mps.statistics.max_bond_dim = checkpoint['statistics']['max_bond_dim']

print(f"Loaded MPS state: {mps.num_qubits} qubits, max χ={max(mps.bond_dims)}")
Strategy 2: Periodic Checkpointing for Long Simulations#
Save checkpoints at regular intervals during time evolution or optimization.
Checkpoint TDVP time evolution:
from atlas_q.tdvp import TDVP
import os

# Configure TDVP
tdvp = TDVP(
    hamiltonian=H,
    mps=mps,
    dt=0.01,
    device='cuda'
)

# Checkpoint directory
checkpoint_dir = 'tdvp_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Evolve with periodic checkpointing
checkpoint_interval = 100  # Steps
total_steps = 10000
energies = []

for step in range(total_steps):
    # Single TDVP step
    E = tdvp.evolve_step()
    energies.append(E)

    # Checkpoint every N steps
    if (step + 1) % checkpoint_interval == 0:
        checkpoint_path = os.path.join(
            checkpoint_dir,
            f'tdvp_step_{step+1:06d}.pt'
        )
        torch.save({
            'step': step + 1,
            'tensors': [t.cpu() for t in tdvp.mps.tensors],
            'bond_dims': tdvp.mps.bond_dims,
            'energy': E,
            'energies': energies.copy(),
            'time': (step + 1) * tdvp.dt
        }, checkpoint_path)
        print(f"[Step {step+1}/{total_steps}] E={E:.6f}, "
              f"checkpoint saved to {checkpoint_path}")

print(f"Completed {total_steps} steps, {total_steps // checkpoint_interval} checkpoints saved")
Resume TDVP from checkpoint:
import glob

# Find latest checkpoint
checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, 'tdvp_step_*.pt')))

if not checkpoints:
    print("No checkpoints found, starting from scratch")
    start_step = 0
    energies = []
else:
    latest_checkpoint = checkpoints[-1]
    print(f"Resuming from {latest_checkpoint}")

    # Load checkpoint (to CPU first)
    checkpoint = torch.load(latest_checkpoint, map_location='cpu')
    start_step = checkpoint['step']
    energies = checkpoint['energies']

    # Restore MPS
    mps = AdaptiveMPS(
        num_qubits=len(checkpoint['tensors']),
        bond_dim=max(checkpoint['bond_dims']),
        device='cuda'
    )
    mps.tensors = [t.to('cuda') for t in checkpoint['tensors']]
    mps.bond_dims = checkpoint['bond_dims']

    # Recreate TDVP
    tdvp = TDVP(hamiltonian=H, mps=mps, dt=0.01, device='cuda')
    print(f"Resumed from step {start_step}, E={checkpoint['energy']:.6f}")

# Continue evolution
for step in range(start_step, total_steps):
    E = tdvp.evolve_step()
    energies.append(E)
    # ... checkpoint as before ...
Strategy 3: Adaptive Checkpointing (Save on Improvement)#
Save checkpoints only when optimization metrics improve, reducing storage overhead.
Checkpoint VQE when energy improves:
from atlas_q.vqe_qaoa import VQE, VQEConfig
import numpy as np
import time
# Configure VQE with callback
config = VQEConfig(
    max_iterations=1000,
    optimizer='adam',
    learning_rate=0.01
)
vqe = VQE(hamiltonian=H, config=config, device='cuda')

# Track best energy and parameters
best_energy = np.inf
best_params = None
checkpoint_count = 0

def checkpoint_callback(params, energy, iteration):
    """Save checkpoint when energy improves."""
    global best_energy, best_params, checkpoint_count

    if energy < best_energy - 1e-6:  # Improvement threshold
        best_energy = energy
        best_params = params.clone()
        checkpoint_count += 1

        torch.save({
            'iteration': iteration,
            'energy': energy,
            'params': params.cpu(),
            'energies': vqe.energies.copy(),
            'param_history': [p.cpu() for p in vqe.param_history],
            'timestamp': time.time()
        }, f'vqe_best_checkpoint_{checkpoint_count:04d}.pt')
        print(f"[Iter {iteration}] New best energy: {energy:.8f} (saved checkpoint)")
    else:
        print(f"[Iter {iteration}] Energy: {energy:.8f} (no improvement)")
# Run VQE with adaptive checkpointing
energy, params = vqe.optimize(callback=checkpoint_callback)
print(f"Final energy: {energy:.8f}")
print(f"Total checkpoints saved: {checkpoint_count}")
Resume VQE with warm start:
# Load best checkpoint
checkpoint_files = glob.glob('vqe_best_checkpoint_*.pt')

if checkpoint_files:
    latest = sorted(checkpoint_files)[-1]
    checkpoint = torch.load(latest, map_location='cpu')
    print(f"Warm starting VQE from checkpoint at E={checkpoint['energy']:.8f}")

    # Use previous best parameters as initial guess
    initial_params = checkpoint['params'].to('cuda')

    # Continue optimization
    energy, params = vqe.optimize(initial_params=initial_params)
else:
    # Cold start
    energy, params = vqe.optimize()
Strategy 4: Compressed Checkpoints#
Reduce checkpoint size using compression and precision reduction.
Compressed MPS checkpoint:
import gzip
import os
import pickle

import torch

def save_compressed_mps(mps, path, precision='float32'):
    """
    Save MPS with compression.

    Parameters
    ----------
    mps : AdaptiveMPS
        MPS to save
    path : str
        Output path (will append .gz)
    precision : str
        'float32' or 'float16' for the stored real/imaginary components
    """
    # Move tensors to CPU and split complex values into real/imaginary
    # components so precision can be reduced without discarding the
    # imaginary part.
    if precision == 'float16':
        tensors = [torch.view_as_real(t.cpu()).half() for t in mps.tensors]
    else:
        tensors = [torch.view_as_real(t.cpu()).float() for t in mps.tensors]

    checkpoint = {
        'tensors': tensors,
        'num_qubits': mps.num_qubits,
        'bond_dims': mps.bond_dims,
        'precision': precision
    }

    # Serialize and compress
    serialized = pickle.dumps(checkpoint, protocol=pickle.HIGHEST_PROTOCOL)
    with gzip.open(path + '.gz', 'wb', compresslevel=6) as f:
        f.write(serialized)

    # Report compression ratio
    uncompressed_size = len(serialized) / 1024**2
    compressed_size = os.path.getsize(path + '.gz') / 1024**2
    compression_ratio = uncompressed_size / compressed_size
    print(f"Saved compressed MPS: {compressed_size:.2f} MB "
          f"(compression ratio: {compression_ratio:.2f}x)")

def load_compressed_mps(path, device='cuda'):
    """Load compressed MPS checkpoint."""
    with gzip.open(path, 'rb') as f:
        serialized = f.read()
    checkpoint = pickle.loads(serialized)

    # Restore MPS
    mps = AdaptiveMPS(
        num_qubits=checkpoint['num_qubits'],
        bond_dim=max(checkpoint['bond_dims']),
        device=device
    )

    # Recombine components into complex64 on the target device
    mps.tensors = [
        torch.view_as_complex(t.float()).to(device) for t in checkpoint['tensors']
    ]
    mps.bond_dims = checkpoint['bond_dims']

    print(f"Loaded compressed MPS (stored precision: {checkpoint['precision']})")
    return mps
# Usage
save_compressed_mps(mps, 'mps_checkpoint.pt', precision='float32')
mps_restored = load_compressed_mps('mps_checkpoint.pt.gz', device='cuda')
Storage comparison:
# Example: n=50, χ=128
# Uncompressed complex64: ~200 MB
# Compressed, float32 components (gzip, lossless): ~50-80 MB (3-4x reduction)
# Compressed, float16 components (gzip, lossy): ~25-40 MB (5-8x reduction)
# Trade-off: float16 loses precision but may be acceptable for checkpointing
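If you are unsure whether float16 components are acceptable for your state, one quick check is to measure the worst-case element error introduced by the down-cast before adopting it. This is a rough sketch against the in-memory tensors, not an ATLAS-Q utility:
# Worst-case absolute error from storing real/imaginary components at float16
max_err = max(
    (torch.view_as_real(t.cpu()) - torch.view_as_real(t.cpu()).half().float())
    .abs().max().item()
    for t in mps.tensors
)
print(f"Max |x - float16(x)| over all MPS tensors: {max_err:.2e}")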
Strategy 5: Distributed State Checkpointing#
Save and load MPS distributed across multiple GPUs or nodes.
Checkpoint distributed MPS (bond-parallel):
from atlas_q.distributed_mps import DistributedMPS, DistributedConfig
import torch.distributed as dist

# Distributed MPS across 4 GPUs
config = DistributedConfig(
    mode='bond_parallel',
    world_size=4,
    backend='nccl',
    device_ids=[0, 1, 2, 3]
)
mps = DistributedMPS(
    num_qubits=80,
    bond_dim=512,  # Split across 4 GPUs = 128 per GPU
    config=config
)

# Each rank saves its partition
rank = dist.get_rank()

# Save local partition
local_checkpoint = {
    'rank': rank,
    'tensors': [t.cpu() for t in mps.local_tensors],
    'bond_dims': mps.local_bond_dims,
    'num_qubits': mps.num_qubits,
    'world_size': config.world_size
}
torch.save(local_checkpoint, f'distributed_mps_rank_{rank}.pt')

# Rank 0 saves metadata
if rank == 0:
    metadata = {
        'num_qubits': mps.num_qubits,
        'global_bond_dim': mps.bond_dim,
        'world_size': config.world_size,
        'mode': config.mode
    }
    torch.save(metadata, 'distributed_mps_metadata.pt')

dist.barrier()  # Synchronize
print(f"Rank {rank}: saved local partition")
Load distributed MPS checkpoint:
# Load metadata
metadata = torch.load('distributed_mps_metadata.pt')

# Reconstruct DistributedConfig
config = DistributedConfig(
    mode=metadata['mode'],
    world_size=metadata['world_size'],
    backend='nccl',
    device_ids=list(range(metadata['world_size']))
)

# Create distributed MPS
mps = DistributedMPS(
    num_qubits=metadata['num_qubits'],
    bond_dim=metadata['global_bond_dim'],
    config=config
)

# Each rank loads its partition
rank = dist.get_rank()
local_checkpoint = torch.load(f'distributed_mps_rank_{rank}.pt')
mps.local_tensors = [t.to(f'cuda:{rank}') for t in local_checkpoint['tensors']]
mps.local_bond_dims = local_checkpoint['bond_dims']

dist.barrier()
print(f"Rank {rank}: loaded local partition")
Strategy 6: Version-Compatible Checkpoints#
Ensure checkpoints remain loadable across ATLAS-Q versions.
Save with version metadata:
import time

import atlas_q
import torch

def save_versioned_checkpoint(mps, path, extra_metadata=None):
    """
    Save checkpoint with version metadata for compatibility.

    Parameters
    ----------
    mps : AdaptiveMPS
        MPS to save
    path : str
        Checkpoint path
    extra_metadata : dict, optional
        Additional metadata to save
    """
    checkpoint = {
        # MPS data
        'tensors': [t.cpu() for t in mps.tensors],
        'num_qubits': mps.num_qubits,
        'bond_dims': mps.bond_dims,
        # Version information
        'atlas_q_version': atlas_q.__version__,
        'torch_version': torch.__version__,
        'checkpoint_format_version': '1.0',
        # Save timestamp
        'timestamp': time.time(),
        'timestamp_readable': time.strftime('%Y-%m-%d %H:%M:%S'),
        # Extra metadata
        'metadata': extra_metadata or {}
    }
    torch.save(checkpoint, path)
    print(f"Saved checkpoint with ATLAS-Q v{atlas_q.__version__}")
def load_versioned_checkpoint(path, device='cuda'):
    """Load checkpoint with version compatibility checks."""
    checkpoint = torch.load(path, map_location='cpu')

    # Check version compatibility
    saved_version = checkpoint.get('atlas_q_version', 'unknown')
    current_version = atlas_q.__version__
    if saved_version != current_version:
        print(f"Warning: Checkpoint saved with ATLAS-Q v{saved_version}, "
              f"loading with v{current_version}")
        # Could implement migration logic here (see the sketch below)

    # Check format version
    format_version = checkpoint.get('checkpoint_format_version', '0.0')
    if format_version != '1.0':
        raise ValueError(f"Unsupported checkpoint format: {format_version}")

    # Load MPS
    mps = AdaptiveMPS(
        num_qubits=checkpoint['num_qubits'],
        bond_dim=max(checkpoint['bond_dims']),
        device=device
    )
    mps.tensors = [t.to(device) for t in checkpoint['tensors']]
    mps.bond_dims = checkpoint['bond_dims']

    print(f"Loaded checkpoint from {checkpoint['timestamp_readable']}")
    return mps, checkpoint.get('metadata', {})
# Usage
save_versioned_checkpoint(mps, 'mps_v1.pt', extra_metadata={'experiment': 'qaoa_maxcut'})
mps_loaded, metadata = load_versioned_checkpoint('mps_v1.pt')
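If the on-disk layout ever changes between format versions, the version check above is the natural place to hook in a migration step before the MPS is reconstructed. The helper below is only a sketch of that pattern; the legacy key name chi_per_bond is hypothetical, not a real ATLAS-Q field:
def migrate_checkpoint(checkpoint):
    """Upgrade an older checkpoint dict to format version 1.0 (illustrative only)."""
    fmt = checkpoint.get('checkpoint_format_version', '0.0')
    if fmt == '0.9':
        # Hypothetical example: an older format stored bond dimensions
        # under a different key; rename it and bump the format version.
        checkpoint['bond_dims'] = checkpoint.pop('chi_per_bond')
        checkpoint['checkpoint_format_version'] = '1.0'
    return checkpoint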
Strategy 7: Checkpointing QAOA Progress#
Save intermediate QAOA results for warm-starting and analysis.
QAOA checkpointing with layer tracking:
from atlas_q.vqe_qaoa import QAOA, QAOAConfig

# Configure QAOA
config = QAOAConfig(
    p=5,  # 5 layers
    optimizer='adam',
    learning_rate=0.02,
    max_iterations=500
)
qaoa = QAOA(hamiltonian=H, config=config, device='cuda')

# Checkpoint directory
checkpoint_dir = 'qaoa_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Optimize with checkpointing every 50 iterations
iteration_count = 0

def qaoa_checkpoint_callback(params, energy, iteration):
    """Checkpoint QAOA progress."""
    global iteration_count
    iteration_count = iteration

    if iteration % 50 == 0:
        # Extract gamma and beta parameters
        p = config.p
        gamma = params[:p]
        beta = params[p:]

        checkpoint = {
            'iteration': iteration,
            'energy': energy,
            'params': params.cpu(),
            'gamma': gamma.cpu(),
            'beta': beta.cpu(),
            'p': p,
            'energies': qaoa.energies.copy(),
            'param_history': [theta.cpu() for theta in qaoa.param_history]
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'qaoa_iter_{iteration:04d}.pt'))
        print(f"[Iter {iteration}] E={energy:.6f}, checkpoint saved")
# Run QAOA
energy, params = qaoa.optimize(callback=qaoa_checkpoint_callback)
# Save final result
torch.save({
    'final_energy': energy,
    'final_params': params.cpu(),
    'all_energies': qaoa.energies,
    'total_iterations': iteration_count
}, os.path.join(checkpoint_dir, 'qaoa_final.pt'))
Troubleshooting#
Checkpoint Loading Fails#
Problem: torch.load() raises RuntimeError: storage has wrong size.
Solution: The checkpoint was saved on a different device or with a different dtype. Load it to CPU first, then move tensors to the target device.
# Always load to CPU first, then move to device
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
mps.tensors = [t.to('cuda') for t in checkpoint['tensors']]
Out of Memory When Loading#
Problem: Loading large checkpoint exhausts CPU or GPU memory.
Solution: Load tensors incrementally or use memory mapping.
# Load tensors one at a time and move them to the GPU incrementally
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
mps = AdaptiveMPS(
    num_qubits=checkpoint['num_qubits'],
    bond_dim=max(checkpoint['bond_dims']),
    device='cuda'
)
for i, tensor in enumerate(checkpoint['tensors']):
    mps.tensors[i] = tensor.to('cuda')
    # Drop the CPU reference and periodically release cached GPU memory
    del tensor
    if i % 10 == 0:
        torch.cuda.empty_cache()
mps.bond_dims = checkpoint['bond_dims']
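The same problem can also be addressed with memory mapping. A minimal sketch, assuming PyTorch 2.1 or newer (where torch.load accepts mmap=True) and a checkpoint written by torch.save; tensor storages then stay backed by the file and are only paged in as they are used:
# Memory-mapped load: storages are not fully materialized in RAM up front
checkpoint = torch.load('checkpoint.pt', map_location='cpu', mmap=True)
for i, tensor in enumerate(checkpoint['tensors']):
    mps.tensors[i] = tensor.to('cuda')  # Pages in one tensor at a time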
Distributed Checkpoint Mismatch#
Problem: Loading distributed checkpoint fails with world_size mismatch.
Solution: A checkpoint saved with 4 GPUs cannot be loaded directly with 2 GPUs; gather all partitions on one device and reshard.
# Load all partitions to CPU
metadata = torch.load('distributed_mps_metadata.pt', map_location='cpu')
all_tensors = []
for rank in range(metadata['world_size']):
    partition = torch.load(f'distributed_mps_rank_{rank}.pt', map_location='cpu')
    all_tensors.extend(partition['tensors'])

# Reconstruct as single-device MPS
mps = AdaptiveMPS(
    num_qubits=metadata['num_qubits'],
    bond_dim=metadata['global_bond_dim'],
    device='cuda'
)
mps.tensors = [t.to('cuda') for t in all_tensors]
# Now redistribute if needed
Checkpoint Size Too Large#
Problem: 50 GB checkpoints fill disk quickly.
Solution: Use compression, reduce precision, or checkpoint less frequently.
# Reduce checkpoint frequency
checkpoint_interval = 200  # Instead of 50

# Use gzip-compressed checkpoints instead of raw complex64
# (float16 components shrink them further, see Strategy 4)
save_compressed_mps(mps, 'checkpoint.pt', precision='float32')

# Keep only the last N checkpoints
max_checkpoints = 5
checkpoints = sorted(glob.glob('checkpoint_*.pt'))
if len(checkpoints) > max_checkpoints:
    for old_checkpoint in checkpoints[:-max_checkpoints]:
        os.remove(old_checkpoint)
        print(f"Removed old checkpoint: {old_checkpoint}")
Summary#
Checkpointing strategies for ATLAS-Q simulations:
Basic checkpointing: Save/load MPS tensors, bond dimensions, and statistics
Periodic checkpointing: Save at fixed intervals during long TDVP/VQE runs
Adaptive checkpointing: Save only when metrics improve to reduce storage
Compressed checkpoints: Use gzip plus reduced-precision components for roughly 3-8x size reduction
Distributed checkpointing: Save partitioned state across multiple devices
Version-compatible checkpoints: Include metadata for cross-version compatibility
Algorithm-specific checkpointing: QAOA, VQE, TDVP with resume capabilities
Choose checkpointing frequency based on the following factors; a rough interval estimate is sketched after the list:
Simulation length: Longer runs need more frequent checkpoints
Storage budget: Compression vs. precision trade-off
Failure rate: Unreliable hardware needs more checkpoints
Restart cost: Expensive initialization favors more checkpoints
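For a quantitative starting point, the Young/Daly approximation balances checkpoint write cost against the expected rework after a failure: interval ≈ sqrt(2 × write_cost × MTBF). The helper below is a back-of-the-envelope sketch, not part of ATLAS-Q, and the numbers are placeholders to replace with measurements from your own runs:
import math

def suggested_checkpoint_interval(write_seconds, mtbf_seconds):
    """Young/Daly approximation: interval ~ sqrt(2 * checkpoint_cost * MTBF)."""
    return math.sqrt(2.0 * write_seconds * mtbf_seconds)

# Placeholder numbers: a 30 s checkpoint write and a 24 h mean time between
# failures suggest checkpointing roughly every 38 minutes.
interval = suggested_checkpoint_interval(write_seconds=30, mtbf_seconds=24 * 3600)
print(f"Checkpoint about every {interval / 60:.0f} minutes")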
See Also#
How to Handle Large Quantum Systems: Distributed MPS checkpointing for multi-GPU systems
Debug Simulations: Load checkpoints to debug intermediate states
How to Optimize Performance: Checkpoint I/O performance optimization
Design Decisions: Serialization format and compatibility
VQE Tutorial: VQE checkpointing examples
TDVP Tutorial: TDVP time evolution checkpointing