Source code for qm_sim.temporal_solver.crank_nicolson

from typing import Callable

import numpy as np
from scipy.linalg import solve_banded
from scipy.sparse import dia_matrix

from .base import TemporalSolver


def I_plus_aH(a: complex, H: dia_matrix) -> dia_matrix:
    """Return :code:`I + a*H`, where :code:`I` is the unit matrix and :code:`a` is a constant.

    :param a: Scalar factor multiplying :code:`H`
    :param H: Sparse matrix in DIA format; it is left unmodified
    :return: New :code:`complex128` DIA matrix equal to :code:`I + a*H`.
        If :code:`H` stores no main diagonal, one is inserted so that the
        offsets stay in sorted order when they were sorted to begin with.
    """
    scaled = H.copy().astype(np.complex128)
    scaled.data *= a
    data, offsets = scaled.data, scaled.offsets

    # In DIA storage column j of a data row holds the entry of matrix column j,
    # so the main diagonal needs at least min(shape) columns. Widen with zeros
    # if the stored rows are shorter, otherwise part of the identity is lost.
    n_main = min(scaled.shape)
    if data.shape[1] < n_main:
        widened = np.zeros((data.shape[0], n_main), dtype=data.dtype)
        widened[:, : data.shape[1]] = data
        data = widened

    main_rows = np.flatnonzero(offsets == 0)
    if main_rows.size:
        data[main_rows[0]] += 1
    else:
        # H has no stored main diagonal: add one consisting purely of the
        # identity instead of failing on the lookup.
        row = int(np.searchsorted(offsets, 0))
        data = np.insert(data, row, 1, axis=0)
        offsets = np.insert(offsets, row, 0)

    return dia_matrix((data, offsets), shape=scaled.shape)
class CrankNicolson(TemporalSolver):
    """Implicit Crank-Nicolson time stepper for 1D systems.

    Each step solves a banded linear system built from the operator returned
    by :code:`H` at the current and the next time level.
    """

    # Solver metadata; presumably read by the TemporalSolver machinery
    # (not visible in this file — confirm against the base class).
    order = 2
    explicit = False
    stable = True
    name = "crank-nicolson"

    def __init__(
        self, H: Callable[[float], np.ndarray], output_shape: tuple[int] = None
    ):
        """Set up the solver.

        :param H: Callable returning the operator at a given time. The stepping
            code needs a DIA sparse matrix (it reads :code:`offsets`/:code:`data`).
        :param output_shape: Forwarded unchanged to :code:`TemporalSolver.__init__`
        :raises ValueError: If :code:`self.output_shape` is not one-dimensional
        """
        super().__init__(H, output_shape)
        if len(self.output_shape) != 1:
            raise ValueError("Crank-Nicolson solver only supports 1D systems")

    @staticmethod
    def _banded_form(A: dia_matrix) -> tuple[tuple[int, int], np.ndarray]:
        """Convert a DIA matrix to the :code:`(l, u), ab` input of :code:`solve_banded`.

        Works for offsets in any order and with gaps; missing diagonals become
        zero rows.
        """
        offsets = np.asarray(A.offsets)
        lower = max(0, -int(offsets.min()))
        upper = max(0, int(offsets.max()))
        n_cols = A.shape[1]
        ab = np.zeros((lower + upper + 1, n_cols), dtype=A.data.dtype)
        width = min(n_cols, A.data.shape[1])
        # solve_banded expects ab[u + i - j, j] == A[i, j]; DIA stores A[i, j]
        # with j - i == offset in column j, so offset k lands in row u - k.
        for diagonal, offset in zip(A.data, offsets):
            ab[upper - int(offset), :width] = diagonal[:width]
        return (lower, upper), ab

    def iterate(
        self,
        v_0: np.ndarray,
        t0: float,
        t_final: float,
        dt: float,
        dt_storage: float = None,
        verbose: bool = True,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Propagate :code:`v_0` from :code:`t0` until :code:`t_final` is reached.

        :param v_0: Initial state
        :param t0: Start time
        :param t_final: Stepping continues while the current time is below this value
        :param dt: Time step, must be positive
        :param dt_storage: Interval, measured from :code:`t0`, between stored
            states. Defaults to :code:`dt` (store every step).
        :param verbose: Forwarded to :code:`self.tqdm`
        :return: Array of stored times and array of stored states; the first
            entry of each is the initial condition
        :raises ValueError: If :code:`dt` or :code:`dt_storage` is not positive
        """
        if dt <= 0:
            raise ValueError("dt must be positive")
        if dt_storage is None:
            dt_storage = dt
        if dt_storage <= 0:
            raise ValueError("dt_storage must be positive")

        H = self.H
        prefactor = dt / 2
        # Times closer than this are treated as equal, so that round-off in
        # products such as 3 * 0.1 neither adds a step nor drops a sample.
        tol = 1e-6 * dt

        step = 0
        tn = t0
        psi_n = v_0

        t = [tn]
        psi = [psi_n]
        with self.tqdm(t0, t_final, verbose) as pbar:
            while tn < t_final - tol:
                # psi^n+1 = psi^n + dt/2*(F^n+1 + F^n)
                # F^n = 1/ihbar * H^n @ psi^n
                # => psi^n+1 = psi^n + dt/2ihbar*(H^n+1 @ psi^n+1 + H^n @ psi^n)
                # (I - dt/2ihbar*H^n+1) @ psi^n+1 = (I + dt/2ihbar*H^n) @ psi^n
                # NOTE(review): only dt/2 is applied here, so self.H presumably
                # already includes the 1/(i*hbar) factor — confirm with callers.
                t_next = t0 + (step + 1) * dt
                lhs = I_plus_aH(-prefactor, H(t_next))
                rhs = I_plus_aH(prefactor, H(tn)) @ psi_n

                bands, ab = self._banded_form(lhs)
                psi_n = solve_banded(bands, ab, rhs)

                pbar.update(dt)
                # Derive the time from the step count instead of accumulating
                # dt, which drifts and can trigger an extra step.
                step += 1
                tn = t_next

                # store data every :code:`dt_storage` seconds, counted from t0;
                # len(psi) already includes the initial state, so the next
                # sample is due at t0 + len(psi) * dt_storage
                if step * dt >= len(psi) * dt_storage - tol:
                    psi.append(psi_n)
                    t.append(tn)

        return np.array(t), np.array(psi)