"""
Module containing the :py:class:`Qsyst` class. An instance contains the drive, detuning, coupling and dissipation parameters of an M-mode system.
The :py:class:`Qsyst` methods will then be used to compute the displacement and covariance matrices of the system after time evolutions under input encoding.
"""
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
from math import prod, factorial
import matplotlib.pyplot as plt
from gausstorch.utils.operations import (
torch_block,
cholesky_inverse_det,
truncate_alpha,
truncate_sigma,
)
from gausstorch.utils.display import (
setup_tex,
plot_evolution_N,
plot_evolution_fock,
fock_states_to_str_list,
)
from gausstorch.utils.bcolors import bcolors
from gausstorch.utils.loop_hafnian_torch import loop_hafnian
torch.set_default_dtype(torch.float64)
# Keys of the physical system parameters: detuning, drive (eA), couplings
# (g, gs) and dissipation rates (k_int, k_ext). Complex quantities are stored
# as separate "_real" / "_imag" entries.
SYST_VARS_KEYS_WITHOUT_BIASES = [
"detuning",
"eA_real",
"eA_imag",
"g_real",
"g_imag",
"gs_real",
"gs_imag",
"k_int",
"k_ext",
]
# Keys of the additional weight/bias parameters. In this file
# ``theta_bias_*``, ``phi_0`` and ``phi_bias`` are read by
# ``Qsyst.encode_theta``; ``W_0`` and ``W_bias`` are not read anywhere in the
# visible code (presumably used by a downstream readout -- TODO confirm).
SYST_VAR_BIAS_KEYS = [
"W_0",
"W_bias",
"theta_bias_real",
"theta_bias_imag",
"phi_0",
"phi_bias",
]
# Every key that ``Qsyst.make_syst_vars`` wraps in an ``nn.Parameter``.
SYST_VARS_KEYS_WITH_BIASES = SYST_VARS_KEYS_WITHOUT_BIASES + SYST_VAR_BIAS_KEYS
[docs]
def init_pars_default(M: int) -> dict:
    """Build the default parameter dictionary for a system of ``M`` modes.

    Args:
        M (int): Number of modes.

    Returns:
        dict: Default parameter values for M modes. Holds the mode count
        ``"M"``, one tensor per (possibly learnable) system parameter and the
        evolution time ``"t_i"``.
    """
    n_couplings = M * (M - 1) // 2  # number of unordered mode pairs
    drive = M * 1e5

    def per_mode(value):
        # One identical entry per mode.
        return value * torch.ones(M)

    def per_coupling(value):
        # One identical entry per pair of modes.
        return value * torch.ones(n_couplings)

    defaults = {"M": M}

    # learnable parameters: encoding weights and biases (single-element tensors)
    defaults["W_0"] = torch.ones(1)
    for bias_key in (
        "W_bias",
        "theta_bias_real",
        "theta_bias_imag",
        "phi_0",
        "phi_bias",
    ):
        defaults[bias_key] = torch.zeros(1)

    # learnable parameters: physical quantities
    defaults["detuning"] = torch.zeros(M)
    defaults["eA_real"] = per_mode(drive.real)
    defaults["eA_imag"] = per_mode(drive.imag)
    defaults["g_real"] = per_coupling(2 * torch.pi * 100e6)
    defaults["g_imag"] = torch.zeros(n_couplings)
    defaults["gs_real"] = per_coupling(2 * torch.pi * 20e6)
    defaults["gs_imag"] = torch.zeros(n_couplings)
    defaults["k_int"] = per_mode(0)
    defaults["k_ext"] = per_mode(2 * torch.pi * 2e6)

    # other parameters
    defaults["t_i"] = torch.tensor(1e-7)
    return defaults
[docs]
class Qsyst(nn.Module):
"""Class allowing for the simulation of coupled gaussian modes."""
def __init__(
    self, init_pars: dict, learnable_vars: list = None, init_print: bool = False
):
    """
    Args:
        init_pars (dict): Dict containing drive, detuning, coupling, dissipation parameters, and time interval. See the :py:func:`init_pars_default` function for a template.
        learnable_vars (list, optional): List containing the names of the parameters to learn (with autograd enabled). Defaults to None, in which case no parameter has gradient computation enabled.
        init_print (bool, optional): Bool checking whether to print a message during instance creation or not. Defaults to False.
    Attributes:
        init_pars (dict): Deep copy of `init_pars`.
        learnable_vars (list): Deep copy of `learnable_vars` (empty list when None was given).
        syst_vars (torch.nn.ParameterDict): Possibly trainable parameters of the system.
        other_pars (dict): Parameters which are never trainable.
        M (int): Number of modes.
        g_shape (int): Number of coupling combinations.
        alpha0 (torch.Tensor): Displacement vector of the vacuum.
        sigma0 (torch.Tensor): Covariance matrix of the vacuum
    """
    super().__init__()
    self.init_pars = deepcopy(init_pars)
    # ``None`` replaces the former mutable ``[]`` default; both mean
    # "nothing is learnable".
    self.learnable_vars = [] if learnable_vars is None else deepcopy(learnable_vars)
    if init_print:
        print(f"{bcolors.BOLD}Initializing Qsyst model{bcolors.ENDC}")
    # System parameters are stored in the ParameterDict syst_vars.
    # Hyper-parameters like the evolution time t_i are stored in other_pars.
    # Both are deep-copied so that they do not share storage with init_pars.
    self.syst_vars = deepcopy(self.make_syst_vars())
    self.other_pars = deepcopy(self.make_other_pars())
    # Disable gradients for non-learnable parameters. Iterate over a snapshot
    # because entries of syst_vars are reassigned inside the loop.
    for key, value in list(self.syst_vars.items()):
        if key not in self.learnable_vars:
            self.syst_vars[key] = nn.Parameter(value.data, requires_grad=False)
    # useful constants to keep
    self.M = self.init_pars["M"]
    self.g_shape = self.M * (self.M - 1) // 2
    self.alpha0 = torch.zeros((2 * self.M, 1), dtype=torch.complex128)
    self.sigma0 = (1 / 2) * torch.eye(2 * self.M, dtype=torch.complex128)
[docs]
def make_syst_vars(self) -> torch.nn.ParameterDict:
    """Wrap every system parameter of ``self.init_pars`` in an ``nn.Parameter``.

    Returns:
        nn.ParameterDict: One gradient-enabled parameter per key listed in
        ``SYST_VARS_KEYS_WITH_BIASES``.
    """
    wrapped = {}
    for name in SYST_VARS_KEYS_WITH_BIASES:
        wrapped[name] = nn.Parameter(self.init_pars[name], requires_grad=True)
    # Built from a plain dict on purpose: the container is populated through
    # exactly the same constructor path as before.
    return nn.ParameterDict(wrapped)
[docs]
def make_other_pars(self) -> dict:
"""
Returns:
dict : Dict containing parameters which can never be learned.
Note:
Only includes the time interval `t_i` at the moment.
"""
other_pars = {key: self.init_pars[key] for key in ["t_i"]}
return other_pars
[docs]
def create_coupling_matrices(
    self,
    detuning: torch.Tensor,
    g_cplx: torch.Tensor,
    gs_cplx: torch.Tensor,
) -> torch.Tensor:
    """Assemble the coupling matrix from detunings and pairwise coupling rates.

    The entries of ``g_cplx`` and ``gs_cplx`` are assigned to the mode pairs
    (0, 1), (0, 2), ..., (0, M-1), (1, 2), ... in that order.

    Args:
        detuning (torch.Tensor): 1D tensor containing the detunings of the nearly resonant drives
        g_cplx (torch.Tensor): 1D tensor containing the photon conversion rates
        gs_cplx (torch.Tensor): 1D tensor containing the two-mode squeezing rates
    Returns:
        torch.Tensor: Coupling matrix (defined in my PhD thesis at equation 5.46)
    """
    n_modes = self.M

    def next_pair(row, col):
        # Advance to the following position of the strict upper triangle.
        col += 1
        if col == n_modes:
            row += 1
            col = row + 1
        return row, col

    # Conversion block: detunings on the diagonal, g above it and conj(g) below.
    conversion = torch.diag(detuning).type(torch.complex128)
    row, col = 0, 1
    for rate in g_cplx:
        conversion[row, col] = rate
        conversion[col, row] = torch.conj(rate)
        row, col = next_pair(row, col)

    # Squeezing block: symmetric, holding gs / 2 for each pair.
    squeezing = torch.zeros((n_modes, n_modes), dtype=torch.complex128)
    row, col = 0, 1
    for rate in gs_cplx:
        squeezing[row, col] = rate / 2
        squeezing[col, row] = rate / 2
        row, col = next_pair(row, col)

    blocks = torch_block(
        conversion,
        2 * squeezing,
        -2 * squeezing.conj().t(),
        -conversion.t(),
    )
    return -1j * blocks
[docs]
def return_theta_xmask(self, theta_key: str) -> tuple:
"""
Args:
theta_key (str): Key of encoding parameter theta
Returns:
tuple: `theta_0` the syst.syst_vars parameter associated to theta_key, and `xmask_shape` the shape of the encoding parameter, for the later linear encoding with xmask
"""
if theta_key == "eA":
theta_0 = self.syst_vars.eA_real + 1j * self.syst_vars.eA_imag
xmask_shape = self.M
elif theta_key == "detuning":
theta_0 = self.syst_vars.detuning
xmask_shape = self.M
elif theta_key == "g":
theta_0 = self.syst_vars.g_real + 1j * self.syst_vars.g_imag
xmask_shape = self.g_shape
elif theta_key == "gs":
theta_0 = self.syst_vars.gs_real + 1j * self.syst_vars.gs_imag
xmask_shape = self.g_shape
else:
raise AssertionError
return theta_0, xmask_shape
[docs]
def encode_theta(
self,
theta_0: torch.torch.Tensor,
x_mask: torch.Tensor,
x_min: torch.Tensor,
x_max: torch.Tensor,
encode_phase: bool,
) -> torch.Tensor:
"""
Args:
theta_0 (torch.Tensor): Encoding parameter value, before encoding the inputs x_mask
x_mask (torch.Tensor): Input values, in a 1-D tensor of the same shape as theta_0
x_min (torch.Tensor): Minimum input value (used to rescale the input between x_min and x_max)
x_max (torch.Tensor): Maximum input value (used to rescale the input between x_min and x_max)
encode_phase (bool): Bool deciding whether to encode into the absolute value or the phase of the encoding parameter (equations 5.105 and 5.106 in my thesis)
Returns:
torch.Tensor: Encoded parameter value
"""
theta_bias = (
self.syst_vars.theta_bias_real + 1j * self.syst_vars.theta_bias_imag
)
# normalize x_mask values between 0 and 1
x_mask = (x_mask.clone() - x_min) / (x_max - x_min)
if encode_phase:
# phase encoded between 0 and pi phase
theta_encoded = theta_bias + torch.mul(
theta_0.clone(),
torch.exp(
1j * (self.syst_vars.phi_0 * x_mask + self.syst_vars.phi_bias)
),
) # should have pi in phase
else: # encode in the amplitude
theta_encoded = torch.mul(theta_0.clone(), x_mask) + theta_bias
return theta_encoded
[docs]
def alpha_sigma_evolution_part_1(
    self, theta_key: str, theta_encoded: torch.Tensor
) -> tuple:
    """Diagonalize the coupling matrix with dissipations.

    Args:
        theta_key (str): Key of the encoding parameters.
        theta_encoded (torch.Tensor): Value of the encoded parameter.
    Returns:
        tuple: `lambda_F, U, Uinv, K_int, K_ext`, the eigenvalue decomposition and dissipation matrices
    """
    syst_vars = self.syst_vars
    detuning = syst_vars.detuning
    g_cplx = syst_vars.g_real + 1j * syst_vars.g_imag
    gs_cplx = syst_vars.gs_real + 1j * syst_vars.gs_imag
    # The encoding parameter takes the place of the Qsyst parameter.
    # "detuning" is offered by return_theta_xmask, so it must be substituted
    # here as well; otherwise encoding into the detuning has no effect.
    # NOTE(review): theta_encoded is used as-is, so a complex bias or a phase
    # encoding yields complex detunings -- confirm this is intended.
    # "eA" is handled in alpha_sigma_evolution_part_2.
    if theta_key == "detuning":
        detuning = theta_encoded
    elif theta_key == "g":
        g_cplx = theta_encoded
    elif theta_key == "gs":
        gs_cplx = theta_encoded
    L0 = self.create_coupling_matrices(
        detuning=detuning, g_cplx=g_cplx, gs_cplx=gs_cplx
    )
    # Dissipation rates are repeated so the diagonal matrices are 2M x 2M.
    K_ext = torch.diag(torch.cat((syst_vars.k_ext, syst_vars.k_ext), dim=0)).type(
        torch.complex128
    )
    K_int = torch.diag(torch.cat((syst_vars.k_int, syst_vars.k_int), dim=0)).type(
        torch.complex128
    )
    K = K_int + K_ext
    F_ = L0 - (K / 2)  # F' in my thesis
    # trick for faster integral computation. Backward only works if eigenvalues are real
    lambda_F, U = torch.linalg.eig(F_)
    Uinv = torch.inverse(U)
    return lambda_F, U, Uinv, K_int, K_ext
[docs]
def alpha_sigma_evolution_part_2(
self,
theta_key: str,
theta_encoded: torch.Tensor,
t: torch.Tensor,
alpha_i: torch.Tensor,
sigma_i: torch.Tensor,
lambda_F: torch.Tensor,
U: torch.Tensor,
Uinv: torch.Tensor,
K_int: torch.Tensor,
K_ext: torch.Tensor,
) -> tuple:
"""
Computes the displacement (alpha) and covariance matrix (sigma) using the coupling matrix eigenvalue decomposition, and the evolution time t.
Args:
theta_key (str): Key of the encoding parameter.
theta_encoded (torch.Tensor): Encoded parameter value.
t (torch.Tensor): Duration of the time evolution.
alpha_i (torch.Tensor): Displacement value.
sigma_i (torch.Tensor): Covariance matrix.
lambda_F (torch.Tensor): Coupling matrix eigenvalue
U (torch.Tensor): Coupling matrix eigenvectors (each column is an eigenvector)
Uinv (torch.Tensor): Inverse of `U`
K_int (torch.Tensor): Diagonal matrix containing internal dissipations
K_ext (torch.Tensor): Diagonal matrix containing external dissipations
Returns:
tuple: Displacement vector and Covariance matrix
"""
M = self.M
eA_cplx = self.syst_vars.eA_real + 1j * self.syst_vars.eA_imag
if theta_key == "eA":
eA_cplx = theta_encoded
eA_cplx = eA_cplx.view(M, 1)
# minus sign to be the same as cascaded formalism. Otherwise, can put plus sign. All in all not important
A_in = -torch.vstack((eA_cplx, eA_cplx.conj())).to(torch.complex128)
sigma0 = (1 / 2) * torch.eye(2 * M, dtype=torch.complex128)
sigma_i = sigma_i.type(torch.complex128)
K = K_int + K_ext
# Now calculation involving t
F_t = (
U @ torch.diag(torch.exp(lambda_F * t)) @ Uinv
) # F_t = torch.matrix_exp(F_ * t)
I1 = torch.diag((1 / lambda_F) * (-1 + torch.exp(lambda_F * t)))
alpha_output = F_t @ alpha_i + U @ I1 @ Uinv @ torch.sqrt(K_ext) @ A_in
# I2 is the integral over [0,t] of exp((L-kappa/2)*tau) @ exp((L-kappa/2)*tau).T
#
# non vectorized computation of I2, but easier to understand:
P = Uinv @ K @ Uinv.conj().t()
I2 = torch.zeros((2 * M, 2 * M), dtype=torch.complex128)
for i in range(2 * M):
for j in range(2 * M):
lambda_F_sum = lambda_F[i] + lambda_F[j].conj()
I2[i, j] = P[i, j] * (torch.exp(lambda_F_sum * t) - 1) / lambda_F_sum
#
# vectorized computation of I2:
# def sum_conj(a, b):
# def sum_b(a0):
# return (a0 * torch.ones_like(b)) + b.conj()
#
# batched_sum_b = torch.func.vmap(sum_b)
# return batched_sum_b(a)
#
# lambda_F_sum = sum_conj(lambda_F, lambda_F)
# P = Uinv @ K @ Uinv.conj().t()
# I2 = P * (torch.exp(lambda_F_sum * t) - 1) / lambda_F_sum
sigma_output = (
F_t @ sigma_i @ F_t.conj().t() + sigma0 @ U @ I2 @ U.conj().t()
) # !!! covariance matrix can have complex values
return alpha_output, sigma_output
[docs]
def alpha_sigma_evolution(
    self,
    t: torch.Tensor,
    alpha_i: torch.Tensor,
    sigma_i: torch.Tensor,
    theta_key: str,
    theta_encoded: torch.Tensor,
) -> tuple:
    """Computes a gaussian state displacement and covariance matrix after a time evolution

    Chains :py:meth:`alpha_sigma_evolution_part_1` (eigenvalue decomposition)
    and :py:meth:`alpha_sigma_evolution_part_2` (evolution over ``t``).

    Args:
        t (torch.Tensor): Duration of evolution from gaussian state (d_i, sigma_i) state to measurement.
        alpha_i (torch.Tensor): Initial displacement vector.
        sigma_i (torch.Tensor): Initial covariance matrix.
        theta_key (str): Name of the encoding parameter
        theta_encoded (torch.Tensor): Value of the encoded parameter.
    Returns:
        tuple: New gaussian state `(alpha_output, sigma_output)`.
    """
    # Time-independent part: eigen-decomposition and dissipation matrices.
    decomposition = self.alpha_sigma_evolution_part_1(
        theta_key=theta_key, theta_encoded=theta_encoded
    )
    eigvals, eigvecs, eigvecs_inv, k_int, k_ext = decomposition
    # Time-dependent part.
    evolved = self.alpha_sigma_evolution_part_2(
        theta_key=theta_key,
        theta_encoded=theta_encoded,
        t=t,
        alpha_i=alpha_i,
        sigma_i=sigma_i,
        lambda_F=eigvals,
        U=eigvecs,
        Uinv=eigvecs_inv,
        K_int=k_int,
        K_ext=k_ext,
    )
    alpha_output, sigma_output = evolved
    return alpha_output, sigma_output
[docs]
@staticmethod
def prob_with_shots(prob: torch.Tensor, n_shots: int) -> torch.Tensor:
"""Computes a probability for a given number of measurement shots with the binomial law
Args:
prob (torch.Tensor): Probability of a Gaussian Boson Sampling (GBS) probability
n_shots (int): Number of measurement shots
Returns:
torch.Tensor: GBS probability with shot measurement noise
"""
new_prob = prob + torch.randn(1) * torch.sqrt(prob * (1 - prob) / n_shots)
return new_prob
[docs]
def prob_gbs(
    self, alpha: torch.Tensor, sigma: torch.Tensor, n: list, n_shots: int = None
) -> torch.Tensor:
    """This function uses the global GBS formula to calculate a Fock state occupation probability from the field operator displacement vector `alpha` and covariance matrix `sigma`.

    Args:
        alpha (torch.Tensor): Field operator displacement vector
        sigma (torch.Tensor): Field operator covariance matrix
        n (list): List containing the number of photons to measure in each mode
        n_shots (int, optional): Number of measurement shots. Defaults to None, in which case perfectly accurate estimations are considered.
    Returns:
        torch.Tensor: Fock state occupation probability
    """
    n_modes = self.M
    eye_m = torch.eye(n_modes)
    eye_2m = torch.eye(2 * n_modes)
    # Block matrix built from (0, I, I, 0), passed in that order to torch_block.
    swap = torch_block(
        torch.zeros(n_modes, n_modes), eye_m, eye_m, torch.zeros(n_modes, n_modes)
    ).type(torch.complex128)

    cov_q = sigma + 0.5 * eye_2m
    cov_q_inv, cov_q_det = cholesky_inverse_det(cov_q)
    kernel = swap @ (eye_2m - cov_q_inv).type(torch.complex128)
    cov_q_inv = cov_q_inv.type(torch.complex128)

    loop_terms = (alpha.conj().t() @ cov_q_inv).squeeze()
    # Each photon number appears twice in the repetition list.
    lhaf = loop_hafnian(kernel, D=loop_terms, reps=n + n)

    gaussian_weight = torch.exp(-0.5 * alpha.conj().T @ cov_q_inv @ alpha)
    normalisation = torch.sqrt(cov_q_det) * prod([factorial(ni) for ni in n])
    probability = (lhaf * gaussian_weight / normalisation).real
    if n_shots is not None:
        probability = self.prob_with_shots(probability, n_shots)
    return probability.squeeze()
[docs]
def prob_gbs_partial_trace(
    self,
    alpha: torch.Tensor,
    sigma: torch.Tensor,
    n: list,
    modes_kept: list,
    n_shots: int = None,
) -> torch.Tensor:
    """This function uses the GBS formula to calculate P(n) from the 1st and 2nd moments.

    All the modes except the ones contained in modes_kept are traced.
    n is the list containing the photon combination to measure, after the partial trace.

    Example:
        If `M = 3`, `modes_kept = [0, 2]`, `n = [2, 4]`,
        This means you trace out mode 1, and compute the probability of measuring 2 photons in mode 0, and 4 photons in mode 2.
    Args:
        alpha (torch.Tensor): Field operator displacement vector.
        sigma (torch.Tensor): Field operator covariance matrix.
        n (list): List containing the number of photons to measure in each considered mode
        modes_kept (list): List of considered modes
        n_shots (int, optional): Number of measurement shots. Defaults to None, in which case perfectly accurate estimations are considered.
    Returns:
        torch.Tensor: Fock state occupation probability
    """
    # Restrict both moments to the kept modes.
    sigma_new = truncate_sigma(sigma, modes_kept)
    alpha_new = truncate_alpha(alpha, modes_kept)
    nn2 = n + n
    M = len(modes_kept)
    id1 = torch.eye(M)
    id2 = torch.eye(2 * M)
    # block torch tensor
    T = torch_block(torch.zeros(M, M), id1, id1, torch.zeros(M, M)).type(
        torch.complex128
    )
    sigmaQ = sigma_new + 0.5 * id2
    sigmaQ_inv, sigmaQ_det = cholesky_inverse_det(sigmaQ)
    A = T @ (id2 - sigmaQ_inv).type(torch.complex128)
    sigmaQ_inv = sigmaQ_inv.type(torch.complex128)
    gamma = (alpha_new.conj().t() @ sigmaQ_inv).squeeze()
    lhaf_A = loop_hafnian(A, D=gamma, reps=nn2)
    result = (
        lhaf_A
        * torch.exp(-0.5 * alpha_new.conj().T @ sigmaQ_inv @ alpha_new)
        / (torch.sqrt(sigmaQ_det) * prod([factorial(ni) for ni in n]))
    )
    # Keep only the real part BEFORE adding shot noise, as prob_gbs does:
    # the binomial standard deviation is only meaningful for a real
    # probability, and a residual imaginary part would otherwise go through
    # a complex square root.
    result = result.real
    if n_shots is not None:
        result = self.prob_with_shots(result, n_shots)
    return result.squeeze()
[docs]
def evolution_N(
    self,
    theta_key: str = "eA",
    x_val: torch.Tensor = torch.tensor(1),
    x_min: torch.Tensor = torch.tensor(0),
    x_max: torch.Tensor = torch.tensor(1),
    encode_phase: bool = False,
    res: int = 1_000,
    compute_plot: bool = True,
    yscale: str = "linear",
    show_plot: bool = True,
    inference_mode: bool = True,
    return_vals: bool = False,
    return_tspan: bool = False,
):
    """Plots or computes the time evolution of the system during a time interval

    The mean photon number of mode ``j`` at each time is computed as
    ``Re(sigma[j, j]) + |alpha[j]|**2 - 0.5``, starting from the vacuum
    (``self.alpha0``, ``self.sigma0``).

    Args:
        theta_key (str, optional): Name of the encoding parameter. Defaults to 'eA'.
        x_val (torch.Tensor, optional): Input value. Defaults to torch.tensor(1).
        x_min (torch.Tensor, optional): Minimum input value (for input rescaling). Defaults to torch.tensor(0).
        x_max (torch.Tensor, optional): Maximum input value (for input rescaling). Defaults to torch.tensor(1).
        encode_phase (bool, optional): If `True`, the input is encoded into the phase of the encoding parameter. Defaults to False.
        res (int, optional): Number of points to compute. Defaults to 1_000.
        compute_plot (bool, optional): If `True`, the plot is computed. Defaults to True.
        yscale (str, optional): Choice of yscale. Defaults to 'linear'.
        show_plot (bool, optional): If `True`, plot is shown with `plt.show()`. Defaults to True.
        inference_mode (bool, optional): If `True`, the computation runs inside `torch.inference_mode()`, so no gradients are computed and the returned tensors cannot take part in autograd. Defaults to True.
        return_vals (bool, optional): If `True`, return photon number values. Defaults to False.
        return_tspan (bool, optional): If `True`, also return discrete time values (only honoured together with `return_vals`). Defaults to False.
    Returns:
        Either None, or photon number values with discrete time values
    Note:
        Calls ``torch.set_num_threads(1)``, which changes a process-wide setting that is not restored afterwards.
    """
    torch.set_num_threads(1)
    # torch.inference_mode(...) only takes effect when entered as a context
    # manager; calling it as a plain statement does nothing.
    with torch.inference_mode(inference_mode):
        # encode the input x into theta_0
        theta_0, x_mask_shape = self.return_theta_xmask(theta_key)
        x_mask = x_val * torch.ones(x_mask_shape, dtype=torch.complex128)
        theta_encoded = self.encode_theta(
            theta_0=theta_0,
            x_mask=x_mask,
            x_min=x_min,
            x_max=x_max,
            encode_phase=encode_phase,
        )
        tspan = torch.linspace(0, self.other_pars["t_i"], res)
        means = torch.zeros((res, self.M))
        # perform the eigenvalue decomposition only once, then use to compute alpha and sigma at all times t from tspan.
        lambda_F, U, Uinv, K_int, K_ext = self.alpha_sigma_evolution_part_1(
            theta_key=theta_key, theta_encoded=theta_encoded
        )
        for i, t in enumerate(tspan):
            alpha_t, sigma_t = self.alpha_sigma_evolution_part_2(
                theta_key=theta_key,
                theta_encoded=theta_encoded,
                t=t,
                alpha_i=self.alpha0,
                sigma_i=self.sigma0,
                lambda_F=lambda_F,
                U=U,
                Uinv=Uinv,
                K_int=K_int,
                K_ext=K_ext,
            )
            # The first M entries of alpha_t and of the diagonal of sigma_t
            # are the ones used for the photon numbers.
            means[i] = (
                torch.real(torch.diagonal(sigma_t)[: self.M])
                + torch.abs(alpha_t[: self.M, 0]) ** 2
                - 0.5
            )
        # renorm by kappa average
        tspan_renormalized = tspan * torch.mean(self.syst_vars.k_ext)
    # plot means
    if compute_plot:
        plot_evolution_N(
            tspan_renormalized.detach().numpy(),
            means.detach().numpy(),
            width_ratio=0.48,
            xlabel=r"time $\times \kappa$",
            yscale=yscale,
        )
        if show_plot:
            plt.show()
    if return_vals:
        if return_tspan:
            return tspan, means
        return means
    return None
[docs]
def evolution_fock(
    self,
    fock_combs_per_mode_comb: dict = None,
    theta_key: str = "eA",
    x_val: torch.Tensor = torch.tensor(1),  # input, of value between 0 and 1
    x_min: torch.Tensor = torch.tensor(0),
    x_max: torch.Tensor = torch.tensor(1),
    encode_phase: bool = False,
    res: int = 1_000,
    compute_plot: bool = True,
    show_plot: bool = True,
    inference_mode: bool = True,
    return_vals: bool = False,
):
    """Plots or compute the Fock state occupation probabilities during a time interval

    Args:
        fock_combs_per_mode_comb (dict, optional): Fock state combinations to plot. Each key is a combination of modes to keep and each value is a collection of photon-number combinations for those modes. Although the default is None, a dict must be supplied.
        theta_key (str, optional): Name of the encoding parameter. Defaults to 'eA'.
        x_val (torch.Tensor, optional): Input value. Defaults to torch.tensor(1).
        x_min (torch.Tensor, optional): Minimum input value (for input rescaling). Defaults to torch.tensor(0).
        x_max (torch.Tensor, optional): Maximum input value (for input rescaling). Defaults to torch.tensor(1).
        encode_phase (bool, optional): If `True`, input is encoded into the phase of the encoding parameter. Defaults to False.
        res (int, optional): Number of discrete time steps. Defaults to 1_000.
        compute_plot (bool, optional): If `True`, compute the plot. Defaults to True.
        show_plot (bool, optional): If `True`, plot is shown with plt.show(). Defaults to True.
        inference_mode (bool, optional): If `True`, the computation runs inside `torch.inference_mode()`, so gradient calculations are disabled and the returned tensor cannot take part in autograd. Defaults to True.
        return_vals (bool, optional): If `True`, return Fock states values. Defaults to False.
    Returns:
        Either None, or torch.Tensor containing Fock state
    Raises:
        ValueError: If ``fock_combs_per_mode_comb`` is None.
    Note:
        Calls ``torch.set_num_threads(1)``, which changes a process-wide setting that is not restored afterwards.
    """
    if fock_combs_per_mode_comb is None:
        raise ValueError(
            "fock_combs_per_mode_comb must be a dict mapping each combination "
            "of kept modes to the Fock state combinations to evaluate."
        )
    torch.set_num_threads(1)
    num_probs = sum(
        len(fock_combinations)
        for fock_combinations in fock_combs_per_mode_comb.values()
    )
    # torch.inference_mode(...) only takes effect when entered as a context
    # manager; calling it as a plain statement does nothing.
    with torch.inference_mode(inference_mode):
        # encode the input x into theta_0
        theta_0, x_mask_shape = self.return_theta_xmask(theta_key)
        x_mask = x_val * torch.ones(x_mask_shape, dtype=torch.complex128)
        theta_encoded = self.encode_theta(
            theta_0=theta_0,
            x_mask=x_mask,
            x_min=x_min,
            x_max=x_max,
            encode_phase=encode_phase,
        )
        tspan = torch.linspace(0, self.other_pars["t_i"], res)
        probs = torch.zeros((res, num_probs))
        # perform the eigenvalue decomposition only once, then use to compute alpha and sigma at all times t from tspan.
        lambda_F, U, Uinv, K_int, K_ext = self.alpha_sigma_evolution_part_1(
            theta_key=theta_key, theta_encoded=theta_encoded
        )
        for i, t in enumerate(tspan):
            alpha_t, sigma_t = self.alpha_sigma_evolution_part_2(
                theta_key=theta_key,
                theta_encoded=theta_encoded,
                t=t,
                alpha_i=self.alpha0,
                sigma_i=self.sigma0,
                lambda_F=lambda_F,
                U=U,
                Uinv=Uinv,
                K_int=K_int,
                K_ext=K_ext,
            )
            # One column of ``probs`` per (mode combination, Fock combination).
            prob_counter = 0
            for mode_comb, fock_combinations in fock_combs_per_mode_comb.items():
                for fock_combination in fock_combinations:
                    probs[i, prob_counter] = self.prob_gbs_partial_trace(
                        alpha_t, sigma_t, fock_combination, mode_comb
                    )
                    prob_counter += 1
        # renorm by kappa average
        tspan_renormalized = tspan * torch.mean(self.syst_vars.k_ext)
    # plot probs
    if compute_plot:
        # At most one label per probability column.
        labels = list(fock_states_to_str_list(fock_combs_per_mode_comb))[:num_probs]
        if len(labels) == 1:
            labels = labels[0]
        plot_evolution_fock(
            tspan=tspan_renormalized.detach().numpy(),
            probs=probs.detach().numpy(),
            labels=labels,
            width_ratio=0.48,
            xlabel=r"time $\times \kappa$",
            ylabel="Probability",
        )
        if show_plot:
            plt.show()
    if return_vals:
        return probs
    return None