Source code for qm_sim.plot

"""

Plots for the solutions to the Schrödinger equation

"""
from typing import Callable

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import ArtistAnimation

from .nature_constants import e_0


def _get_plot_fun(ndim: int, ax: plt.Axes = None):
    if ax is None:
        ax = plt

    if ndim == 1:
        return lambda *args, **kwargs: ax.plot(*args, c="b", **kwargs)
    elif ndim == 2:
        # To preserve the return type of :code:`plot`above
        return lambda *args, **kwargs: [ax.imshow(*args, **kwargs)]
    elif ndim == 3:
        raise NotImplementedError("3D plots not yet supported")
    else:
        # It would be impressive if this is ever executed
        raise ValueError(f"Invalid system dimensionality: {ndim}D")


def _shape_from_int(n: int) -> tuple[int, int]:
    """Get plot shape from n (plot count)"""
    if n < 1:
        raise ValueError("Plot count must be positive")
    shapes = {
        1: (1, 1),
        2: (1, 2),
        3: (1, 3),
        4: (2, 2),
        5: (2, 3),
        6: (2, 3),
        7: (2, 4),
        8: (2, 4),
        9: (3, 3),
        10: (3, 4),
        11: (3, 4),
        12: (3, 4),
    }
    return shapes.get(n, (1, n))


[docs] def eigen(E: np.ndarray, psi: np.ndarray): """Plot absolute square of wave function :param E: Eigenenergies :type E: np.ndarray :param psi: Eigenfunctions :type psi: np.ndarray """ n = E.shape[0] ndim = len(psi.shape) - 1 shape = _shape_from_int(n) plt.figure() plt.suptitle("$|\\Psi|^2$") plot = _get_plot_fun(ndim) for i in range(n): plt.subplot(*shape, i + 1) plt.title(f"E{i} = {E[i] / e_0 :.3f} eV") plot(abs(psi[i]) ** 2) plt.tight_layout() plt.show() return
[docs] def temporal(t: np.ndarray, psi: np.ndarray, Vt: Callable[[float], np.ndarray]): """Create an animation of the wave function, alongside the potential :param t: Times correspondind to the wave functions :type t: np.ndarray :param psi: Wave funcions :type psi: np.ndarray :param Vt: Function of time, returning the potential at that time :type Vt: Callable[[float], np.ndarray] """ ndim = len(psi.shape) - 1 # Get potential at each timestep V = np.array([Vt(tn) for tn in t]) # We want to plot the probability distribution psi2 = np.abs(psi) ** 2 # Plot the results fig, (ax1, ax2) = plt.subplots(2, 1) ax1_plot = _get_plot_fun(ndim, ax1) ax2_plot = _get_plot_fun(ndim, ax2) (psi_plot,) = ax1_plot(psi2[0, :]) (V_plot,) = ax2_plot(V[0, :] / e_0) ax1.set_title(f"$|\\Psi|^2$") ax2.set_title("Potential [eV]") if ndim == 1: ax1.set_ylim(0, np.max(psi2) * 1.1) ax2.set_ylim(np.min(V / e_0), np.max(V / e_0)) elif ndim == 2: ax1.set_xticks([]) ax1.set_yticks([]) ax2.set_xticks([]) ax2.set_yticks([]) fig.tight_layout() ims = [(psi_plot, V_plot)] for n in range(psi2.shape[0]): (psi_plot,) = ax1_plot(psi2[n, ...], animated=True) (V_plot,) = ax2_plot(V[n, ...] / e_0, animated=True) ims.append( ( psi_plot, V_plot, ) ) ani = ArtistAnimation(fig, ims, blit=True, interval=50) plt.show() return
[docs] def potential(V: np.ndarray): """Plot the potential :param V: Potential array :type V: np.ndarray """ plt.figure() _get_plot_fun(len(V.shape))(V / e_0) plt.title("Potential [eV]") plt.show() return