Source code for skwdro.base.samplers.torch.newsvendor_sampler
from typing import Optional, Union
import torch as pt
import skwdro.distributions as dst
from skwdro.base.samplers.torch.base_samplers import IsOptionalCovarianceSampler, NoLabelsSampler
[docs]
class NewsVendorNormalSampler(NoLabelsSampler, IsOptionalCovarianceSampler):
    """
    Sampler for the Newsvendor problem drawing its inputs from a
    multivariate normal distribution.
    """
    # Distribution object handed to the parent sampler constructor.
    data_s: dst.MultivariateNormal

    def __init__(
        self,
        xi,
        *,
        sigma: Optional[Union[pt.Tensor, float]] = None,
        tril: Optional[pt.Tensor] = None,
        prec: Optional[pt.Tensor] = None,
        cov: Optional[pt.Tensor] = None,
        seed: Optional[int] = None,
    ):
        """
        Build a gaussian sampler centered at ``xi``.

        The covariance description is given through keyword-only arguments,
        which are forwarded together with the last dimension of ``xi`` to
        ``init_covar``; whatever mapping it returns is passed as keyword
        arguments to the distribution constructor.

        Parameters
        ----------
        xi: pt.Tensor
            mean for inputs
        sigma: float|Tensor
            scalar standard deviation shared through dimensions, for inputs.
        tril, prec, cov: Tensor, optional
            alternative covariance specifications, see
            :py:class:`skwdro.base.samplers.torch.base_samplers.IsOptionalCovarianceSampler`
        seed: int, optional
            forwarded unchanged to the parent sampler constructor.
        """
        # Event dimension is read from the trailing axis of the mean.
        covar_kwargs = self.init_covar(xi.size(-1), sigma, tril, prec, cov)
        gaussian = dst.MultivariateNormal(loc=xi, **covar_kwargs)  # type: ignore
        super().__init__(gaussian, seed)

    def reset_mean(
        self,
        xi: pt.Tensor,
        xi_labels: Optional[pt.Tensor],
    ):
        """
        Re-run the constructor on this instance with a new mean.

        The seed currently stored on the instance and the lower-triangular
        scale factor held by the current distribution are reused, so only
        the mean changes.

        Parameters
        ----------
        xi: pt.Tensor
            new mean for inputs
        xi_labels: pt.Tensor, optional
            ignored (discarded immediately).
        """
        del xi_labels
        current_seed = self.seed
        # NOTE(review): reads a private attribute of the distribution object;
        # confirm it remains available in the backing distributions library.
        current_tril = self.data_s._unbroadcasted_scale_tril
        NewsVendorNormalSampler.__init__(
            self, xi, seed=current_seed, tril=current_tril
        )