# Source code for atlas_q.diagnostics

"""
Diagnostics and Monitoring Utilities for MPS

Provides entropy calculations, statistics tracking, and observability tools.

Author: ATLAS-Q Contributors
Date: October 2025
License: MIT
"""

import math
from typing import Dict, List

import torch


def bond_entropy_from_S(S: torch.Tensor) -> float:
    """
    Compute entanglement entropy from singular values

    S = -Σ_i p_i log(p_i) where p_i = σ_i² / Σ_j σ_j²

    Args:
        S: Singular values (sorted descending)

    Returns:
        Entanglement entropy in bits (log base 2)
    """
    if len(S) == 0:
        return 0.0
    weights = S * S
    probabilities = weights / weights.sum()
    # Clamp keeps log2 finite when a probability underflows to zero
    entropy = -(probabilities * torch.log2(probabilities.clamp(min=1e-30))).sum()
    return float(entropy.item())
def effective_rank(S: torch.Tensor, threshold: float = 0.99) -> int:
    """
    Compute effective rank: smallest k where cumulative energy ≥ threshold

    Args:
        S: Singular values
        threshold: Energy retention threshold (default 99%)

    Returns:
        Effective rank
    """
    if len(S) == 0:
        return 0
    cumulative = (S * S).cumsum(0)
    target = threshold * cumulative[-1]
    # searchsorted returns the first index whose cumulative energy reaches
    # target; +1 converts the 0-based index into a rank count.
    rank = int(torch.searchsorted(cumulative, target).item()) + 1
    return min(rank, len(S))
def spectral_gap(S: torch.Tensor, k: int) -> float:
    """
    Compute spectral gap σ_k / σ_{k+1}

    Large gaps indicate safe truncation points.

    Args:
        S: Singular values
        k: Truncation point

    Returns:
        Spectral gap ratio (or inf if k+1 doesn't exist or is zero)
    """
    # Guard clauses: the ratio is undefined outside 1 <= k <= len(S)-1;
    # short-circuit ensures S[k] is only read when k is in range.
    out_of_range = k < 1 or k >= len(S)
    if out_of_range or S[k] == 0:
        return float("inf")
    return float((S[k - 1] / S[k]).item())
class MPSStatistics:
    """
    Track and aggregate MPS operation statistics

    Maintains per-operation logs and rolling aggregates for:
    - Bond dimensions (χ)
    - Truncation errors (ε_local)
    - Entanglement entropies (S)
    - SVD driver usage
    - Computation times
    """

    def __init__(self):
        # One parallel list per tracked metric; record() appends to
        # whichever of these keys the caller supplies.
        self.logs: Dict[str, List] = {
            "step": [],
            "bond": [],
            "k_star": [],
            "chi_before": [],
            "chi_after": [],
            "eps_local": [],
            "entropy": [],
            "svd_driver": [],
            "dtype": [],
            "ms_elapsed": [],
            "condS": [],
        }

    def record(self, **kwargs):
        """Record a single operation (keys not in self.logs are ignored)"""
        for key, value in kwargs.items():
            if key in self.logs:
                self.logs[key].append(value)

    def summary(self) -> Dict[str, float]:
        """Compute summary statistics

        Returns:
            Mapping of aggregate name -> value; empty logs yield 0.0.
        """
        import numpy as np

        def safe_agg(key: str, fn):
            # Aggregate one log column; return 0.0 for empty or
            # non-aggregatable data instead of raising.
            if not self.logs[key]:
                return 0.0
            try:
                return float(fn(np.array(self.logs[key])))
            # Was a bare `except:` — that also swallowed SystemExit and
            # KeyboardInterrupt. Exception still covers numpy's errors.
            except Exception:
                return 0.0

        return {
            "total_operations": len(self.logs["step"]),
            "max_chi": safe_agg("chi_after", np.max),
            "mean_chi": safe_agg("chi_after", np.mean),
            "sum_eps2": safe_agg("eps_local", lambda x: (x**2).sum()),
            "max_eps": safe_agg("eps_local", np.max),
            "mean_entropy": safe_agg("entropy", np.mean),
            "p95_entropy": safe_agg("entropy", lambda x: np.percentile(x, 95)),
            "total_time_ms": safe_agg("ms_elapsed", np.sum),
            "cuda_svd_pct": self._driver_percentage("torch_cuda"),
            "cpu_fallback_pct": self._driver_percentage("torch_cpu"),
        }

    def _driver_percentage(self, driver_name: str) -> float:
        """Compute percentage of operations using specific driver"""
        if not self.logs["svd_driver"]:
            return 0.0
        count = sum(1 for d in self.logs["svd_driver"] if d == driver_name)
        return 100.0 * count / len(self.logs["svd_driver"])

    def global_error_bound(self) -> float:
        """Compute global error upper bound: sqrt(Σ ε_i²) over all local errors"""
        if not self.logs["eps_local"]:
            return 0.0
        return math.sqrt(sum(e**2 for e in self.logs["eps_local"]))

    def reset(self):
        """Clear all logs"""
        for key in self.logs:
            self.logs[key].clear()