Source code for gausstorch.utils.param_processing

import torch
import torch.nn as nn
from copy import deepcopy


[docs] def rescale_law( unscaled_val: torch.Tensor, par_key: str, R: torch.Tensor ) -> torch.Tensor: """Returns rescaled value of a physical parameter. Args: unscaled_val (torch.Tensor): unscaled value of parameter to set par_key (str): name of the parameter whose value to substitute R (torch.Tensor): rescaling factor Returns: torch.Tensor: Rescaled parameter """ scaled_val = unscaled_val if par_key in [ "detuning", "g_real", "g_imag", "gs_real", "gs_imag", "k_int", "k_ext", ]: scaled_val = unscaled_val / R elif par_key in ["eA_real", "eA_imag"]: scaled_val = unscaled_val / torch.sqrt(R) elif par_key == "t_i": scaled_val = unscaled_val * R return scaled_val
[docs] def unscale_law( rescaled_val: torch.Tensor, par_key: str, R: torch.Tensor ) -> torch.Tensor: """Performs the inverse operation to :py:func:`rescale_law` Args: rescaled_val (torch.Tensor): rescaled value of parameter to unscale par_key (str): name of the parameter whose value to substitute R (torch.Tensor): rescaling factor Raises: NameError: If the `par_key` parameter key is not valid Returns: torch.Tensor: Unscaled value """ if par_key in [ "detuning", "g_real", "g_imag", "gs_real", "gs_imag", "k_int", "k_ext", ]: scaled_val = rescaled_val * R elif par_key in ["eA_real", "eA_imag"]: scaled_val = rescaled_val * torch.sqrt(R) else: raise NameError return scaled_val
[docs] def rescale_pars(pars: dict, R: torch.tensor) -> dict: """Returns a dict with all rescaled parameters from the `pars` argument Args: pars (dict): Contains parameter key-value pairs R (torch.tensor): Scaling parameter Returns: dict: Dict with rescaled parameters """ rescaled_pars = {key: rescale_law(val, key, R) for key, val in pars.items()} return deepcopy(rescaled_pars) # avoid memory shenanigans