# Source code for atlas_q.truncation

"""
Adaptive Truncation Policy for MPS

Implements energy-based rank selection with per-bond caps and budget enforcement.

Mathematical foundation:
- Keep smallest k such that Σ_{i≤k} σ_i² ≥ (1-ε²) Σ_i σ_i²
- Local error: ε_local² = Σ_{i>k} σ_i²
- Entropy: S = -Σ_i p_i log(p_i) where p_i = σ_i² / Σ_j σ_j²

Author: ATLAS-Q Contributors
Date: October 2025
License: MIT
"""

from typing import Callable, Tuple

import torch


def choose_rank_from_sigma(
    S: torch.Tensor,
    eps_bond: float,
    chi_cap: int,
    budget_ok: Callable[[int], bool] = lambda k: True,
) -> Tuple[int, float, float, float]:
    """
    Adaptive rank selection from singular values

    Args:
        S: Singular values (sorted descending)
        eps_bond: Energy tolerance (truncation threshold)
        chi_cap: Maximum allowed rank for this bond. Values below 1 are
            treated as 1 so that at least one singular value is always kept.
        budget_ok: Function that checks if rank k is within memory budget

    Returns:
        k: Selected rank (0 only when S is empty, otherwise >= 1)
        eps_local: Local truncation error, sqrt of the discarded energy
        entropy: Entanglement entropy (natural log) of the kept spectrum
        condS: Condition number (σ_max / σ_k); inf when k <= 1 or σ_k == 0

    Strategy:
        1. Find k_tol from energy criterion: cumulative energy ≥ (1-ε²) * total
        2. Cap by chi_cap (never below 1)
        3. Reduce if budget violation
        4. Compute diagnostics (entropy, condition number, error)
    """
    if len(S) == 0:
        return 0, 0.0, 0.0, float("inf")

    # Energy criterion: keep singular values until (1-ε²) of energy retained
    E = (S * S).cumsum(0)
    total = E[-1]

    if total < 1e-30:
        # Degenerate case: numerically zero spectrum
        return 1, 0.0, 0.0, float("inf")

    thresh = (1.0 - eps_bond**2) * total
    k_tol = int(torch.searchsorted(E, thresh).item()) + 1
    k_tol = min(k_tol, len(S))  # Can't exceed available singular values

    # Apply per-bond cap. The floor of 1 matters: with k == 0 the index
    # E[k - 1] below would wrap around to E[-1] and report a zero truncation
    # error even though the whole spectrum was discarded.
    k = max(1, min(k_tol, chi_cap))

    # Apply budget constraint (greedy reduction, never below rank 1)
    while k > 1 and not budget_ok(k):
        k -= 1

    # Local truncation error = sqrt(discarded energy); the clamp guards
    # against tiny negative values caused by floating-point round-off
    eps_local = torch.sqrt(torch.clamp(total - E[k - 1], min=0.0))

    # Entanglement entropy of the renormalised kept spectrum
    kept_energy = S[:k] * S[:k]
    p = kept_energy / kept_energy.sum()
    entropy = float(-(p * torch.log(torch.clamp(p, min=1e-30))).sum().item())

    # Condition number
    if k > 1 and S[k - 1] > 0:
        condS = float((S[0] / S[k - 1]).item())
    else:
        condS = float("inf")

    return k, float(eps_local.item()), entropy, condS
def compute_global_error_bound(local_errors: list) -> float:
    """
    Combine per-bond truncation errors into one global error bound.

    The bound is the root of the summed squares (Frobenius-style):
        ε_global ≤ sqrt(Σ_b ε²_local,b)

    Args:
        local_errors: List of local truncation errors, one per bond

    Returns:
        Upper bound on global state error (0.0 for an empty list)
    """
    import math

    squared_errors = [err**2 for err in local_errors]
    return math.sqrt(sum(squared_errors))
def check_entropy_sanity(entropy: float, chi_left: int, chi_right: int) -> bool:
    """
    Check that an entropy value does not exceed the bond's upper bound.

    The bound used is, for bond dimensions χ_L and χ_R:
        S_max = log₂(min(χ_L · 2, 2 · χ_R))

    Args:
        entropy: Measured entanglement entropy
        chi_left: Left bond dimension
        chi_right: Right bond dimension

    Returns:
        True if entropy <= S_max within a 1e-6 numerical tolerance
    """
    import math

    tolerance = 1e-6  # Allow small numerical error
    limiting_dim = min(chi_left * 2, 2 * chi_right)
    ceiling = math.log2(limiting_dim)
    return entropy <= ceiling + tolerance
# ============================================================================= # IR-Enhanced Truncation (New in v0.7.0) # ============================================================================= # IR constants E2_THRESHOLD = 0.135 # e^-2 threshold for observability
def analyze_truncation_regime(
    S: torch.Tensor,
    site_index: int = 0,
    decoherence_rate: float = 0.05,
    initial_coherence: float = 1.0,
):
    """
    Analyze regime BEFORE truncation decision (IR-correct approach).

    Thin wrapper that converts the singular values to a NumPy array (when
    given a tensor) and forwards everything to
    ``ir_enhanced.regime_analyzer.analyze_mps_bond_regime``.

    Args:
        S: Singular values (sorted descending). A tensor is detached, moved
            to CPU and converted to NumPy; any object without a ``cpu``
            attribute is passed through unchanged.
        site_index: Distance from preparation (for L5 decoherence)
        decoherence_rate: Decay rate α in R̄(D) = R₀ e^(-αD)
        initial_coherence: Initial coherence R₀

    Returns:
        RegimeAnalysis from ir_enhanced.regime_analyzer
    """
    from .ir_enhanced.regime_analyzer import analyze_mps_bond_regime

    if hasattr(S, "cpu"):
        # Tensor.numpy() raises RuntimeError on tensors that require grad,
        # so drop the autograd graph before converting.
        tensor = S.detach() if hasattr(S, "detach") else S
        singular_values = tensor.cpu().numpy()
    else:
        singular_values = S

    return analyze_mps_bond_regime(
        singular_values=singular_values,
        site_index=site_index,
        decoherence_rate=decoherence_rate,
        initial_coherence=initial_coherence,
    )
def choose_rank_with_regime(
    S: torch.Tensor,
    eps_bond: float,
    chi_cap: int,
    site_index: int = 0,
    decoherence_rate: float = 0.05,
    initial_coherence: float = 1.0,
    budget_ok: Callable[[int], bool] = lambda k: True,
) -> Tuple[int, float, float, float, "RegimeAnalysis"]:
    """
    IR-correct rank selection: Diagnose regime FIRST, then truncate.

    Steps:
        1. Analyze the regime via ``analyze_truncation_regime``
        2. Scale the tolerance by regime: 0.5x (IR), 1x (TRANSITION),
           2x (any other regime, i.e. AIR)
        3. Run the energy-based truncation with the scaled tolerance

    Args:
        S: Singular values (sorted descending)
        eps_bond: Base energy tolerance
        chi_cap: Maximum allowed rank. Values below 1 are treated as 1 so
            that at least one singular value is always kept.
        site_index: Distance from preparation site
        decoherence_rate: Decay rate α
        initial_coherence: Initial coherence R₀
        budget_ok: Budget constraint function

    Returns:
        (k, eps_local, entropy, condS, regime_analysis)
        condS is inf when k <= 1 or the smallest kept singular value is 0.
    """
    from .ir_enhanced.regime_analyzer import ObservabilityRegime

    # Step 1: Diagnose regime BEFORE truncation
    regime_analysis = analyze_truncation_regime(
        S, site_index, decoherence_rate, initial_coherence
    )

    # Step 2: Select epsilon based on regime
    if regime_analysis.regime == ObservabilityRegime.IR:
        # Structure observable - preserve it with tight tolerance
        effective_eps = eps_bond * 0.5
    elif regime_analysis.regime == ObservabilityRegime.TRANSITION:
        # Near boundary - standard tolerance
        effective_eps = eps_bond
    else:
        # AIR regime: structure hidden - aggressive truncation is safe
        effective_eps = eps_bond * 2.0

    # Step 3: Standard energy-based truncation with regime-adjusted eps
    if len(S) == 0:
        return 0, 0.0, 0.0, float("inf"), regime_analysis

    E = (S * S).cumsum(0)
    total = E[-1]

    if total < 1e-30:
        # Degenerate case: numerically zero spectrum
        return 1, 0.0, 0.0, float("inf"), regime_analysis

    thresh = (1.0 - effective_eps**2) * total
    k_tol = int(torch.searchsorted(E, thresh).item()) + 1
    k_tol = min(k_tol, len(S))

    # Floor at 1: with k == 0 the index E[k - 1] would wrap to E[-1] and
    # report zero truncation error although everything was discarded.
    k = max(1, min(k_tol, chi_cap))

    while k > 1 and not budget_ok(k):
        k -= 1

    # Compute diagnostics
    eps_local = torch.sqrt(torch.clamp(total - E[k - 1], min=0.0))

    kept_energy = S[:k] * S[:k]
    p = kept_energy / kept_energy.sum()
    entropy = float(-(p * torch.log(torch.clamp(p, min=1e-30))).sum().item())

    if k > 1 and S[k - 1] > 0:
        condS = float((S[0] / S[k - 1]).item())
    else:
        condS = float("inf")

    return k, float(eps_local.item()), entropy, condS, regime_analysis
def choose_rank_with_coherence(
    S: torch.Tensor,
    eps_bond: float,
    chi_cap: int,
    coherence: float,
    budget_ok: Callable[[int], bool] = lambda k: True,
    e2_threshold: float = E2_THRESHOLD,
) -> Tuple[int, float, float, float, float]:
    """
    IR-enhanced rank selection incorporating coherence metrics.

    The truncation tolerance is scaled according to the coherence regime:
    - coherence > e2_threshold: tolerance x0.5 (preserve structure)
    - coherence > e2_threshold / 2: tolerance unchanged (transition)
    - otherwise: tolerance x2 (aggressive truncation)

    Args:
        S: Singular values (sorted descending)
        eps_bond: Energy tolerance (truncation threshold)
        chi_cap: Maximum allowed rank for this bond. Values below 1 are
            treated as 1 so that at least one singular value is always kept.
        coherence: Current coherence metric R̄ (0-1)
        budget_ok: Function that checks if rank k is within memory budget
        e2_threshold: e^-2 threshold (default: 0.135)

    Returns:
        k: Selected rank
        eps_local: Local truncation error
        entropy: Entanglement entropy at this bond
        condS: Condition number (σ_max / σ_k); inf when k <= 1 or σ_k == 0
        coherence_adjusted: Regime indicator: 1.0 (IR), 0.5 (transition),
            0.0 (AIR, and also for an empty or numerically zero spectrum)

    Strategy:
        1. Pick the tolerance from the coherence regime
        2. Apply the standard energy criterion with that tolerance
        3. Apply caps and budget constraints
    """
    if len(S) == 0:
        return 0, 0.0, 0.0, float("inf"), 0.0

    # Energy criterion: keep singular values until (1-ε²) of energy retained
    E = (S * S).cumsum(0)
    total = E[-1]

    if total < 1e-30:
        # Degenerate case: numerically zero spectrum
        return 1, 0.0, 0.0, float("inf"), 0.0

    # Adjust epsilon based on coherence regime
    if coherence > e2_threshold:
        # IR regime: structure observable, use a tighter tolerance
        effective_eps = eps_bond * 0.5
        coherence_adjusted = 1.0
    elif coherence > e2_threshold * 0.5:
        # Transition regime: keep the base tolerance
        effective_eps = eps_bond
        coherence_adjusted = 0.5
    else:
        # AIR regime: structure hidden, use a looser tolerance
        effective_eps = eps_bond * 2.0
        coherence_adjusted = 0.0

    thresh = (1.0 - effective_eps**2) * total
    k_tol = int(torch.searchsorted(E, thresh).item()) + 1
    k_tol = min(k_tol, len(S))

    # Apply per-bond cap. Floor at 1: with k == 0 the index E[k - 1] would
    # wrap to E[-1] and report zero truncation error although everything
    # was discarded.
    k = max(1, min(k_tol, chi_cap))

    # Apply budget constraint
    while k > 1 and not budget_ok(k):
        k -= 1

    # Compute diagnostics
    eps_local = torch.sqrt(torch.clamp(total - E[k - 1], min=0.0))

    kept_energy = S[:k] * S[:k]
    p = kept_energy / kept_energy.sum()
    entropy = float(-(p * torch.log(torch.clamp(p, min=1e-30))).sum().item())

    if k > 1 and S[k - 1] > 0:
        condS = float((S[0] / S[k - 1]).item())
    else:
        condS = float("inf")

    return k, float(eps_local.item()), entropy, condS, coherence_adjusted
def choose_rank_with_decoherence(
    S: torch.Tensor,
    eps_bond: float,
    chi_cap: int,
    decoherence_rate: float,
    site_index: int,
    initial_coherence: float = 1.0,
    budget_ok: Callable[[int], bool] = lambda k: True,
    e2_threshold: float = E2_THRESHOLD,
) -> Tuple[int, float, float, float, float]:
    """
    IR L5-based rank selection with exponential decoherence model.

    Estimates the coherence at this site from R̄(D) = R₀ e^(-αD) and then
    delegates the truncation to ``choose_rank_with_coherence``.

    Args:
        S: Singular values (sorted descending)
        eps_bond: Energy tolerance (truncation threshold)
        chi_cap: Maximum allowed rank for this bond
        decoherence_rate: Decay rate α (higher = faster decay)
        site_index: Distance from preparation site D
        initial_coherence: Initial coherence R₀ (default: 1.0)
        budget_ok: Function that checks if rank k is within memory budget
        e2_threshold: e^-2 threshold (default: 0.135)

    Returns:
        k: Selected rank
        eps_local: Local truncation error
        entropy: Entanglement entropy
        condS: Condition number
        estimated_coherence: Estimated coherence at this site

    References:
        - IR Paper Section 4.5: "L5: Exponential Decoherence Law"
        - Formula: R̄(D) = R₀ e^(-αD)
    """
    import math

    # Estimate coherence at this site using L5
    estimated_coherence = initial_coherence * math.exp(-decoherence_rate * site_index)

    # Use coherence-aware truncation. Its fifth return value is a regime
    # indicator (1.0 / 0.5 / 0.0), not the coherence itself, so it is
    # replaced by the estimate this function is documented to return.
    k, eps_local, entropy, condS, _regime_indicator = choose_rank_with_coherence(
        S=S,
        eps_bond=eps_bond,
        chi_cap=chi_cap,
        coherence=estimated_coherence,
        budget_ok=budget_ok,
        e2_threshold=e2_threshold,
    )
    return k, eps_local, entropy, condS, estimated_coherence
def compute_spectral_coherence(S: torch.Tensor) -> float:
    """
    Fraction of the total spectral weight carried by the first singular value.

    With p_i = σ_i² / Σ_j σ_j², the returned value is p_0:
    - p = [1, 0, 0, ...]     → 1 (all weight in one mode)
    - p = [1/n, 1/n, ...]    → 1/n (weight spread evenly over n modes)

    Args:
        S: Singular values (sorted descending, so index 0 is the dominant mode)

    Returns:
        Spectral coherence R̄ ∈ [0, 1]; 0.0 for an empty input or when the
        total squared weight is below 1e-30
    """
    if len(S) == 0:
        return 0.0

    power = S * S
    norm = power.sum()
    if norm < 1e-30:
        return 0.0

    # Only the leading normalised weight is needed
    dominant_share = power[0] / norm
    return float(dominant_share.item())
def adaptive_chi_from_coherence(
    coherence_history: list,
    current_chi: int,
    chi_min: int = 2,
    chi_max: int = 256,
    e2_threshold: float = E2_THRESHOLD,
    growth_rate: float = 1.2,
    shrink_rate: float = 0.8,
) -> int:
    """
    Four-layer hierarchy feedback: Observable → Structure adaptation.

    Adjusts the bond dimension from the mean of the last (up to) five
    coherence measurements:
    - mean < e2_threshold: shrink χ by shrink_rate
    - mean > 0.5: grow χ by growth_rate
    - otherwise: keep χ
    The result is clamped to [chi_min, chi_max].

    Args:
        coherence_history: Recent coherence measurements (R̄ values)
        current_chi: Current bond dimension
        chi_min: Minimum allowed bond dimension
        chi_max: Maximum allowed bond dimension
        e2_threshold: e^-2 threshold
        growth_rate: Multiplier for chi growth (default: 1.2)
        shrink_rate: Multiplier for chi shrinkage (default: 0.8)

    Returns:
        Adjusted bond dimension; current_chi unchanged (and unclamped) when
        coherence_history is empty

    References:
        - IR Paper Section 2: "The Four-Layer Framework"
        - Observable layer feeds back to Structure layer
    """
    if not coherence_history:
        return current_chi

    # Use recent average coherence
    recent = coherence_history[-5:]
    avg_coherence = sum(recent) / len(recent)

    if avg_coherence < e2_threshold:
        # AIR regime: structure hidden, reduce resources
        new_chi = int(current_chi * shrink_rate)
    elif avg_coherence > 0.5:
        # High coherence: structure valuable, allow growth
        new_chi = int(current_chi * growth_rate)
        if growth_rate > 1.0:
            # int() truncation makes the multiplicative step a no-op for
            # small chi (e.g. int(4 * 1.2) == 4), whereas shrinking always
            # moves by at least 1. Without this floor a bond that shrank to
            # a small chi could never grow again.
            new_chi = max(new_chi, current_chi + 1)
    else:
        # Moderate coherence: maintain current
        new_chi = current_chi

    # Apply bounds
    return max(chi_min, min(chi_max, new_chi))
def coherence_aware_truncation_policy(
    site_coherences: list,
    base_eps: float = 1e-6,
    base_chi: int = 64,
    e2_threshold: float = E2_THRESHOLD,
) -> Tuple[list, list]:
    """
    Generate per-site truncation policy based on coherence distribution.

    Each site is assigned to one of four tiers by its coherence R̄:

        R̄ > 2 * e2_threshold      eps = 0.5 * base_eps   chi = int(1.5 * base_chi)
        R̄ > e2_threshold          eps = base_eps         chi = base_chi
        R̄ > 0.5 * e2_threshold    eps = 1.5 * base_eps   chi = int(0.8 * base_chi)
        otherwise                 eps = 2.0 * base_eps   chi = int(0.5 * base_chi)

    Args:
        site_coherences: Coherence R̄ at each site (any iterable)
        base_eps: Base truncation tolerance
        base_chi: Base bond dimension
        e2_threshold: e^-2 threshold

    Returns:
        eps_per_site: Truncation tolerance for each site
        chi_per_site: Bond dimension cap for each site
    """
    eps_per_site = []
    chi_per_site = []

    for coherence in site_coherences:
        if coherence > e2_threshold * 2:
            # High coherence: tight truncation, high chi
            site_eps = base_eps * 0.5
            site_chi = int(base_chi * 1.5)
        elif coherence > e2_threshold:
            # Above threshold: standard
            site_eps = base_eps
            site_chi = base_chi
        elif coherence > e2_threshold * 0.5:
            # Near threshold: slightly loose
            site_eps = base_eps * 1.5
            site_chi = int(base_chi * 0.8)
        else:
            # Low coherence: aggressive truncation
            site_eps = base_eps * 2.0
            site_chi = int(base_chi * 0.5)

        eps_per_site.append(site_eps)
        chi_per_site.append(site_chi)

    return eps_per_site, chi_per_site