"""
Adaptive Matrix Product State for Moderate-to-High Entanglement
Extends MatrixProductStatePyTorch with:
- Adaptive bond dimension by tolerance
- Per-bond χ caps and global memory budget
- Mixed precision (complex64/complex128) support
- Two-site gate application with automatic SVD truncation
- Comprehensive logging and diagnostics
Mathematical guarantees:
- Local error control: ε²_local = Σ_{i>k} σ_i² / Σ_i σ_i² ≤ ε²_bond
- Global error bound: ε_global ≤ sqrt(Σ_b ε²_local,b)
- Entanglement entropy: S_b = -Σ_i p_i log p_i, where p_i = σ_i² / Σ_j σ_j²
Author: ATLAS-Q Contributors
Date: October 2025
License: MIT
"""
import cmath
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from .diagnostics import MPSStatistics
from .linalg_robust import robust_qr, robust_svd
from .mps_pytorch import MatrixProductStatePyTorch
from .truncation import check_entropy_sanity, choose_rank_from_sigma
# Triton-accelerated gate operations (if available)
try:
from triton_kernels.mps_complex import fused_two_qubit_gate_pytorch, fused_two_qubit_gate_triton
TRITON_AVAILABLE = True
except ImportError:
TRITON_AVAILABLE = False
fused_two_qubit_gate_pytorch = None
fused_two_qubit_gate_triton = None
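# Error-budget arithmetic sketch (illustrative only): with the per-bond
# tolerance eps_bond = 1e-6 from the module docstring, 10^4 two-site
# truncations give a global bound of eps_global <= sqrt(1e4 * (1e-6)**2) = 1e-4.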
@dataclass
class DTypePolicy:
"""Mixed precision policy configuration"""
default: torch.dtype = torch.complex64
promote_if_cond_gt: float = 1e6 # Promote to complex128 if cond(S) exceeds this
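# Usage sketch (hypothetical values): a stricter policy that promotes to
# complex128 sooner; AdaptiveMPS below accepts it via ``dtype_policy``:
#     policy = DTypePolicy(default=torch.complex64, promote_if_cond_gt=1e4)
#     mps = AdaptiveMPS(8, dtype_policy=policy, device="cpu")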
class AdaptiveMPS(MatrixProductStatePyTorch):
"""
Adaptive MPS for moderate-to-high entanglement simulation
Key features:
- Variable per-bond dimensions with adaptive truncation
- Energy-based rank selection: keep k such that Σ_{i≤k} σ_i² ≥ (1-ε²) Σ_i σ_i²
- Per-bond χ caps and global memory budget enforcement
- Mixed precision with automatic promotion on numerical instability
- Two-site gate application (TEBD-style)
- Comprehensive statistics and error tracking
Example:
>>> mps = AdaptiveMPS(16, bond_dim=8, eps_bond=1e-6, chi_max_per_bond=64)
>>> H = torch.tensor([[1, 1], [1, -1]], dtype=torch.complex64) / torch.sqrt(torch.tensor(2.0))
>>> for q in range(16):
...     mps.apply_single_qubit_gate(q, H)
>>> CZ = torch.diag(torch.tensor([1, 1, 1, -1], dtype=torch.complex64))
>>> for i in range(0, 15, 2):
...     mps.apply_two_site_gate(i, CZ)
>>> print(mps.stats_summary())
"""
def __init__(
self,
num_qubits: int,
bond_dim: int = 8,
*,
eps_bond: float = 1e-6,
chi_max_per_bond: Optional[Union[int, List[int]]] = 256,
budget_global_mb: Optional[float] = None,
dtype_policy: Optional[DTypePolicy] = None,
device: str = "cuda",
dtype: Optional[torch.dtype] = None,
):
"""
Initialize Adaptive MPS
Args:
num_qubits: Number of qubits
bond_dim: Initial bond dimension
eps_bond: Energy tolerance for truncation (default 1e-6)
chi_max_per_bond: Max χ per bond (int or list of ints)
budget_global_mb: Global memory budget in MB (None = unlimited)
dtype_policy: Mixed precision policy (a fresh DTypePolicy is created if None)
device: 'cuda' or 'cpu'
dtype: Explicit dtype (overrides dtype_policy.default if provided)
"""
# Resolve the precision policy, then the working dtype: an explicit dtype
# overrides the policy default
if dtype_policy is None:
dtype_policy = DTypePolicy()
if dtype is None:
dtype = dtype_policy.default
super().__init__(num_qubits, bond_dim, device, dtype)
self.eps_bond = eps_bond
self.dtype_policy = dtype_policy
self.budget_global_mb = budget_global_mb
# Per-bond χ caps
if chi_max_per_bond is None:
# No per-bond cap; None is passed through to choose_rank_from_sigma
self.chi_max_per_bond = [None] * (num_qubits - 1)
elif isinstance(chi_max_per_bond, int):
self.chi_max_per_bond = [chi_max_per_bond] * (num_qubits - 1)
else:
assert (
len(chi_max_per_bond) == num_qubits - 1
), f"chi_max_per_bond must have length {num_qubits - 1}"
self.chi_max_per_bond = list(chi_max_per_bond)
# Track actual bond dimensions (initially uniform)
self.bond_dims = [bond_dim] * (num_qubits - 1)
# Statistics tracker
self.statistics = MPSStatistics()
self._operation_counter = 0
# Initialize to computational zero state |00...0⟩
self._initialize_zero_state()
def _initialize_zero_state(self):
"""Initialize MPS to computational zero state |00...0⟩"""
# For |00...0⟩, only the [*, 0, *] entry of each tensor is non-zero
# All tensors: set [:, 1, :] = 0 and [:, 0, :] = identity-like
for i in range(self.num_qubits):
T = self.tensors[i]
# Zero out the |1⟩ component
T[:, 1, :] = 0
# Set |0⟩ component to identity (or close to it)
T[:, 0, :] = torch.eye(T.shape[0], T.shape[2], dtype=T.dtype, device=T.device)
# Normalize
self._normalize()
def _log_operation(self, **kwargs):
"""Log an operation to statistics"""
self.statistics.record(step=self._operation_counter, **kwargs)
self._operation_counter += 1
def stats_summary(self) -> Dict[str, float]:
"""Get summary statistics"""
return self.statistics.summary()
def global_error_bound(self) -> float:
"""Get global error upper bound"""
return self.statistics.global_error_bound()
def reset_stats(self):
"""Reset statistics tracking"""
self.statistics.reset()
self._operation_counter = 0
@torch.no_grad()
def apply_single_qubit_gate(self, q: int, U2: torch.Tensor):
"""
Apply single-qubit gate (fast path, no truncation needed)
Args:
q: Qubit index
U2: 2x2 unitary gate
Complexity: O(χ²)
"""
assert 0 <= q < self.num_qubits, f"Qubit index {q} out of range"
assert U2.shape == (2, 2), "U2 must be 2x2"
T = self.tensors[q]
# Contract: T[a,s,b] * U[s,t] -> T[a,t,b]
# Move gate to same device and dtype as tensor
U2_device = U2.to(device=T.device, dtype=T.dtype)
self.tensors[q] = torch.einsum("st,asb->atb", U2_device, T)
# Mark MPS as non-canonical after gate application
self.is_canonical = False
@torch.no_grad()
def apply_two_site_gate(self, i: int, U4: torch.Tensor):
"""
Apply two-qubit gate with adaptive SVD truncation
This is the core TEBD operation for moderate entanglement.
Args:
i: Bond index (applies to qubits i and i+1)
U4: 4x4 unitary gate (or 2x2x2x2 tensor)
Steps:
1. Merge tensors at sites (i, i+1) into Θ
2. Apply gate U
3. SVD: Θ = U S V†
4. Adaptively select rank k by energy criterion + caps
5. Split back into two cores with updated χ
Complexity: O(χ³) for SVD
"""
assert 0 <= i < self.num_qubits - 1, f"Bond index {i} out of range"
start_time = time.time()
A, B = self.tensors[i], self.tensors[i + 1]
χL, χM, χR = A.shape[0], A.shape[2], B.shape[2]
device = A.device
# Choose dtype (start with policy default, may promote later)
local_dtype = self.dtype_policy.default
A = A.to(dtype=local_dtype)
B = B.to(dtype=local_dtype)
# Reshape U4 to 4x4 if needed
if U4.shape == (2, 2, 2, 2):
U_matrix = U4.reshape(4, 4)
elif U4.shape == (4, 4):
U_matrix = U4
else:
raise ValueError(f"U4 must be (2,2,2,2) or (4,4), got {U4.shape}")
U_matrix = U_matrix.to(device=device, dtype=local_dtype)
# Steps 1-3: Fused gate application (using Triton if available on CUDA)
# This fuses: merge tensors + apply gate + reshape for SVD
use_triton = TRITON_AVAILABLE and device.type == "cuda"
if use_triton:
try:
# Use Triton-accelerated fused kernel (1.5-3× faster)
# Input: A[li,2,ri], B[ri,2,rj], U[4,4]
# Output: X[li*2, 2*rj] ready for SVD
X = fused_two_qubit_gate_triton(A, B, U_matrix)
except Exception:
# Fall back to PyTorch if Triton fails
X = fused_two_qubit_gate_pytorch(A, B, U_matrix)
else:
# Standard PyTorch path (no Triton available or on CPU)
if fused_two_qubit_gate_pytorch is not None:
X = fused_two_qubit_gate_pytorch(A, B, U_matrix)
else:
# Fallback: manual einsum operations
Theta = torch.einsum("asm,mtb->astb", A, B) # [χL, 2, 2, χR]
U = U_matrix.view(2, 2, 2, 2)
Theta_new = torch.einsum("stuv,astb->auvb", U, Theta)
X = Theta_new.reshape(χL * 2, 2 * χR)
# Step 4: SVD with fallback
U, S, Vh, driver = robust_svd(X)
# Step 5: Adaptive rank selection
cap = self.chi_max_per_bond[i]
def budget_ok(k: int) -> bool:
if self.budget_global_mb is None:
return True
# Estimate the memory delta of resizing the two site tensors
# [χL, 2, χM] + [χM, 2, χR] -> [χL, 2, k] + [k, 2, χR].
# finfo reports the component width, so a complex element is 2× that.
bytes_per_elem = 2 * (torch.finfo(local_dtype).bits // 8)
before = (χL * χM + χM * χR) * 2 * bytes_per_elem
after = (χL * k + k * χR) * 2 * bytes_per_elem
delta = max(0, after - before)
current_mb = self.memory_usage() / (1024**2)
return (current_mb + delta / (1024**2)) < self.budget_global_mb
k, eps_loc, entropy, condS = choose_rank_from_sigma(S, self.eps_bond, cap, budget_ok)
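# Worked example (illustrative, assuming choose_rank_from_sigma implements the
# energy criterion from the class docstring): for normalized S = [0.9, 0.43, 0.07]
# and eps_bond = 0.1, keeping k = 2 discards 0.07² ≈ 0.005 ≤ 0.01 = ε², while
# k = 1 would discard 0.43² + 0.07² ≈ 0.19, so k = 2 is selected.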
# Step 6: Check if we need to promote precision
if math.isfinite(condS) and condS > self.dtype_policy.promote_if_cond_gt:
# Promote to complex128 and redo the SVD + rank selection (the gate was
# already applied in working precision; promotion sharpens the truncation)
X_promoted = X.to(torch.complex128)
U, S, Vh, driver = robust_svd(X_promoted)
k, eps_loc, entropy, condS = choose_rank_from_sigma(S, self.eps_bond, cap, budget_ok)
local_dtype = torch.complex128
# Step 7: Sanity check entropy
if not check_entropy_sanity(entropy, χL, χR):
print(f"Warning: Entropy {entropy:.3f} exceeds physical bound at bond {i}")
# Step 8: Rebuild cores
US = U[:, :k] * S[:k] # [2χL, k]
VhK = Vh[:k, :] # [k, 2χR]
A_new = US.reshape(χL, 2, k) # [χL, 2, k]
B_new = VhK.reshape(k, 2, χR) # [k, 2, χR]
# Step 9: Update tensors
self.tensors[i] = A_new
self.tensors[i + 1] = B_new
self.bond_dims[i] = k
# Step 10: Log operation
elapsed_ms = (time.time() - start_time) * 1000
self._log_operation(
bond=i,
k_star=k,
chi_before=χM,
chi_after=k,
eps_local=eps_loc,
entropy=entropy,
svd_driver=driver,
dtype=str(local_dtype),
ms_elapsed=elapsed_ms,
condS=condS,
)
# Mark MPS as non-canonical after gate application
self.is_canonical = False
@torch.no_grad()
def to_left_canonical(self):
"""
Bring MPS into left-canonical form using QR sweeps
After this, every tensor A^[i] except the last satisfies:
Σ_s (A^[i]_s)† A^[i]_s = I
(the last tensor absorbs the overall norm)
Complexity: O(n · χ³)
"""
for i in range(self.num_qubits - 1):
A = self.tensors[i]
χL, p, χR = A.shape
# QR decomposition
Q, R, _ = robust_qr(A.reshape(χL * p, χR))
χmid = Q.shape[1]
# Update this tensor
self.tensors[i] = Q.reshape(χL, p, χmid)
# Absorb R into next tensor
B = self.tensors[i + 1]
self.tensors[i + 1] = torch.einsum("ij,jkl->ikl", R, B)
# Update bond dimension (if bond_dims exists - may not during __init__)
if hasattr(self, 'bond_dims') and i < len(self.bond_dims):
self.bond_dims[i] = χmid
self.is_canonical = True
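# Verification sketch: after the sweep, every non-final core A should satisfy
# the left-canonical condition stated in the docstring, e.g.
#     A = mps.tensors[0]
#     eye = torch.einsum("asb,asc->bc", A.conj(), A)
#     assert torch.allclose(eye, torch.eye(eye.shape[0], dtype=eye.dtype), atol=1e-5)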
# Alias for compatibility with base class
def canonicalize_left_to_right(self):
"""Alias for to_left_canonical() for compatibility"""
self.to_left_canonical()
@torch.no_grad()
def to_mixed_canonical(self, center: int):
"""
Bring MPS into mixed-canonical form with center at specified site
Sites 0..center-1 are left-canonical
Sites center+1..n-1 are right-canonical
Site center holds the normalization
Args:
center: Center site index
Complexity: O(n · χ³)
"""
assert 0 <= center < self.num_qubits, "Center out of range"
# Left-canonicalize up to center
for i in range(center):
A = self.tensors[i]
χL, p, χR = A.shape
Q, R, _ = robust_qr(A.reshape(χL * p, χR))
χmid = Q.shape[1]
self.tensors[i] = Q.reshape(χL, p, χmid)
B = self.tensors[i + 1]
self.tensors[i + 1] = torch.einsum("ij,jkl->ikl", R, B)
self.bond_dims[i] = χmid
# Right-canonicalize from the end down to center+1
for i in range(self.num_qubits - 1, center, -1):
B = self.tensors[i]
χL, p, χR = B.shape
# Reshape to [p·χR, χL] so QR yields an isometry over the (p, χR) legs
B_mat = B.permute(1, 2, 0).reshape(p * χR, χL)
Q, R, _ = robust_qr(B_mat)
χmid = Q.shape[1]
# Plain transpose (no conjugate) so the product A·B is preserved exactly
self.tensors[i] = Q.t().reshape(χmid, p, χR)
# Absorb R into the previous tensor and record the new bond dimension
A = self.tensors[i - 1]
self.tensors[i - 1] = torch.einsum("ijk,kl->ijl", A, R.t())
self.bond_dims[i - 1] = χmid
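# Property sketch: in mixed-canonical form the norm lives entirely in the
# center tensor, so for a normalized state
#     C = mps.tensors[center]
#     norm_sq = torch.einsum("asb,asb->", C, C.conj()).real
# evaluates to ≈ 1.0.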
def snapshot(self, path: str):
"""
Save MPS to file for checkpointing
Args:
path: File path to save to
"""
torch.save(
{
"tensors": [t.cpu() for t in self.tensors],
"bond_dims": self.bond_dims,
"num_qubits": self.num_qubits,
"eps_bond": self.eps_bond,
"chi_max_per_bond": self.chi_max_per_bond,
"statistics": self.statistics.logs,
},
path,
)
@staticmethod
def load_snapshot(path: str, device: str = "cuda") -> "AdaptiveMPS":
"""
Load MPS from checkpoint file
Args:
path: File path to load from
device: Device to place tensors on
Returns:
Loaded AdaptiveMPS instance
"""
data = torch.load(path, map_location="cpu")
mps = AdaptiveMPS(
num_qubits=data["num_qubits"],
bond_dim=max(data["bond_dims"]),
eps_bond=data["eps_bond"],
chi_max_per_bond=data["chi_max_per_bond"],
device=device,
)
mps.tensors = [t.to(device) for t in data["tensors"]]
mps.bond_dims = data["bond_dims"]
if "statistics" in data:
mps.statistics.logs = data["statistics"]
return mps
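# Roundtrip sketch (hypothetical path):
#     mps.snapshot("checkpoint.pt")
#     restored = AdaptiveMPS.load_snapshot("checkpoint.pt", device="cpu")
#     assert restored.bond_dims == mps.bond_dims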
def get_memory_usage(self) -> int:
"""
Get total memory usage in bytes
Returns:
Total bytes used by all tensors
"""
return self.memory_usage()
@torch.no_grad()
def to_statevector(self) -> torch.Tensor:
"""
Convert MPS to full statevector (ONLY for small systems!)
Returns:
Full statevector of size 2^n
Warning:
This scales as O(2^n) and should ONLY be used for testing
small systems (n ≤ 20). For larger systems, use get_amplitude().
"""
if self.num_qubits > 20:
raise ValueError(
f"Converting {self.num_qubits} qubits to statevector "
f"would require {2**self.num_qubits} amplitudes. "
f"Use get_amplitude() for large systems."
)
# Contract all tensors: T[0] * T[1] * ... * T[n-1]
# Cast all cores to a common dtype (the MPS working dtype when available),
# so cores promoted to higher precision contract cleanly
target_dtype = self.dtype if hasattr(self, 'dtype') else self.tensors[0].dtype
result = self.tensors[0].to(dtype=target_dtype) # [1, 2, χ₁]
for i in range(1, self.num_qubits):
# result: [..., χᵢ]
# T[i]: [χᵢ, 2, χᵢ₊₁]
# Contract over χᵢ dimension
# Ensure both tensors have same dtype
tensor_i = self.tensors[i].to(dtype=target_dtype)
result = torch.einsum("...i,ijk->...jk", result, tensor_i)
# Final tensor: [2, 2, ..., 2, 1]
# Squeeze the trailing dimension and flatten
result = result.squeeze(-1).reshape(-1)
return result
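# Verification sketch (small n only): a freshly initialized AdaptiveMPS encodes
# |00...0>, so its statevector is the first computational basis vector:
#     psi = AdaptiveMPS(3, device="cpu").to_statevector()
#     assert torch.allclose(psi.abs(), torch.eye(8)[0], atol=1e-6)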
# ===========================================================================
# Named Gate Methods (Qiskit/Cirq Compatibility)
# ===========================================================================
# These methods provide convenient wrappers that use the Triton-optimized
# apply_single_qubit_gate() and apply_two_site_gate() methods
def h(self, qubit: int):
"""Hadamard gate"""
H = torch.tensor([[1, 1], [1, -1]], dtype=self.dtype, device=self.device) / math.sqrt(2)
self.apply_single_qubit_gate(qubit, H)
def x(self, qubit: int):
"""Pauli X gate"""
X = torch.tensor([[0, 1], [1, 0]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, X)
def y(self, qubit: int):
"""Pauli Y gate"""
Y = torch.tensor([[0, -1j], [1j, 0]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Y)
def z(self, qubit: int):
"""Pauli Z gate"""
Z = torch.tensor([[1, 0], [0, -1]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Z)
def s(self, qubit: int):
"""Phase gate (S gate)"""
S = torch.tensor([[1, 0], [0, 1j]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, S)
def sdg(self, qubit: int):
"""S dagger gate"""
Sdg = torch.tensor([[1, 0], [0, -1j]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Sdg)
def t(self, qubit: int):
"""T gate"""
T = torch.tensor([[1, 0], [0, (1 + 1j) / math.sqrt(2)]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, T)
def tdg(self, qubit: int):
"""T dagger gate"""
Tdg = torch.tensor([[1, 0], [0, (1 - 1j) / math.sqrt(2)]], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Tdg)
def rx(self, qubit: int, theta: float):
"""Rotation around X axis"""
Rx = torch.tensor([
[math.cos(theta / 2), -1j * math.sin(theta / 2)],
[-1j * math.sin(theta / 2), math.cos(theta / 2)]
], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Rx)
def ry(self, qubit: int, theta: float):
"""Rotation around Y axis"""
Ry = torch.tensor([
[math.cos(theta / 2), -math.sin(theta / 2)],
[math.sin(theta / 2), math.cos(theta / 2)]
], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Ry)
def rz(self, qubit: int, theta: float):
"""Rotation around Z axis"""
Rz = torch.tensor([
[cmath.exp(-1j * theta / 2), 0],
[0, cmath.exp(1j * theta / 2)]
], dtype=self.dtype, device=self.device)
self.apply_single_qubit_gate(qubit, Rz)
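# Note: rx/ry/rz implement exp(-i*theta*P/2) for P in {X, Y, Z}; for example
# rz(q, math.pi) equals a Z gate up to the global phase -1j.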
def cnot(self, control: int, target: int):
"""CNOT gate (uses Triton-optimized two-site gate)"""
if abs(control - target) != 1:
raise NotImplementedError("Only adjacent qubit gates supported in MPS")
if control > target:
# A reversed CNOT (control on the higher site) needs a basis change,
# e.g. conjugation by Hadamards or SWAPs; not yet supported
raise NotImplementedError("Reversed control/target not yet supported")
CNOT = torch.tensor([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1],
[0, 0, 1, 0]
], dtype=self.dtype, device=self.device)
self.apply_two_site_gate(control, CNOT)
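# Routing sketch: non-adjacent pairs can be handled manually with the swap()
# method below, e.g. swap(1, 2); cnot(0, 1); swap(1, 2) acts as a CNOT
# between qubits 0 and 2.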
def cx(self, control: int, target: int):
"""Alias for CNOT"""
self.cnot(control, target)
def cz(self, q0: int, q1: int):
"""Controlled-Z gate"""
if abs(q0 - q1) != 1:
raise NotImplementedError("Only adjacent qubit gates supported in MPS")
q0, q1 = min(q0, q1), max(q0, q1)
CZ = torch.tensor([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, -1]
], dtype=self.dtype, device=self.device)
self.apply_two_site_gate(q0, CZ)
def cy(self, control: int, target: int):
"""Controlled-Y gate"""
if abs(control - target) != 1:
raise NotImplementedError("Only adjacent qubit gates supported in MPS")
if control > target:
raise NotImplementedError("Reversed control/target not yet supported")
CY = torch.tensor([
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, -1j],
[0, 0, 1j, 0]
], dtype=self.dtype, device=self.device)
self.apply_two_site_gate(control, CY)
def swap(self, q0: int, q1: int):
"""SWAP gate"""
if abs(q0 - q1) != 1:
raise NotImplementedError("Only adjacent qubit gates supported in MPS")
q0, q1 = min(q0, q1), max(q0, q1)
SWAP = torch.tensor([
[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]
], dtype=self.dtype, device=self.device)
self.apply_two_site_gate(q0, SWAP)
def sample(self, num_shots: int = 1):
"""Sample measurement outcomes - delegates to parent class"""
return self.measure(num_shots)
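if __name__ == "__main__":
# Minimal smoke-test sketch (an illustrative example, not part of the API;
# assumes the parent class supports device="cpu"): build a 4-qubit GHZ
# state using only the methods defined above.
demo = AdaptiveMPS(4, bond_dim=2, eps_bond=1e-8, device="cpu")
demo.h(0)
for q in range(3):
demo.cnot(q, q + 1)
psi = demo.to_statevector()
# Expect |0000> and |1111> amplitudes of ≈ 1/sqrt(2) each
print("psi[0] =", psi[0].item(), "psi[-1] =", psi[-1].item())
print(demo.stats_summary())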