"""
This is an extra module that holds functions at module top level, so that they
can be serialized and used effortlessly with multiprocessing. Here, the main
:py:func:`fit()` function is defined.
"""
import numpy as np
from pickle import dump
from typing import TypedDict, Union
from nptyping import Float, NDArray, Shape
from metrics_as_scores.distribution.distribution import DistTransform, Dataset
from metrics_as_scores.distribution.fitting import Continuous_RVs, Discrete_RVs, Fitter, StatisticalTest, StatisticalTestJson
from scipy.stats._distn_infrastructure import rv_continuous, rv_discrete
from timeit import default_timer as timer
Continuous_RVs_dict: dict[str, type[rv_continuous]] = {
    type(rv).__name__: type(rv) for rv in Continuous_RVs
}
"""
Lookup table for the continuous random variables listed in ``Continuous_RVs``
(those supported by `scipy.stats`). Each key is the class name of a random
variable, and each value is the class itself.
Note this is a dictionary of types, rather than instances.
"""
Discrete_RVs_dict: dict[str, type[rv_discrete]] = {
    type(rv).__name__: type(rv) for rv in Discrete_RVs
}
"""
Lookup table for the discrete random variables listed in ``Discrete_RVs``
(those supported by `scipy.stats`). Each key is the class name of a random
variable, and each value is the class itself.
Note this is a dictionary of types, rather than instances.
"""
def get_data_tuple(
    ds: Dataset,
    qtype: str,
    dist_transform: DistTransform,
    continuous_transform: bool=True
) -> list[tuple[str, NDArray[Shape["*"], Float], Union[float, None]]]:
    """
    This method is part of the workflow for computing parametric fits.
    For a specific type of quantity and transform, it creates datasets
    for all available contexts. Each context (including the virtual
    ``'__ALL__'`` context, for which ``context=None`` is passed to
    :meth:`Dataset.data()`) contributes two entries: one obtained with
    ``unique_vals=True`` and one with ``unique_vals=False``.

    ds: ``Dataset``
        The dataset to obtain contexts and data from.

    qtype: ``str``
        The type of quantity to get datasets for.

    dist_transform: ``DistTransform``
        The chosen distribution transform.

    continuous_transform: ``bool``
        Whether the transform is real-valued or must be converted to integer.

    :rtype: ``list[tuple[str, NDArray[Shape["*"], Float], Union[float, None]]]``

    :return: A list of tuples of three elements:

        1. A key of the form ``"{context}_{qtype}"``, with the suffix ``"_u"``
           appended if the data was requested using unique values (see the
           ``unique_vals`` argument of :meth:`Dataset.data()`).
        2. The data as returned by :meth:`Dataset.transform()`.
        3. The transform value as returned by :meth:`Dataset.transform()`
           (typed ``Union[float, None]``, matching how :py:func:`fit()`
           accepts it).
    """
    datasets = []
    for ctx in ds.contexts(include_all_contexts=True):
        # The aggregate context is requested from the dataset with `None`.
        context = None if ctx == '__ALL__' else ctx
        for unique_vals in [True, False]:
            data = ds.data(qtype=qtype, context=context, unique_vals=unique_vals, sub_sample=25_000)
            transform_value, data = Dataset.transform(data=data, dist_transform=dist_transform, continuous_value=continuous_transform)
            key = f"{ctx}_{qtype}{('_u' if unique_vals else '')}"
            datasets.append((key, data, transform_value))
    return datasets
class FitResult(TypedDict):
    """
    This class is derived from :py:class:`TypedDict` and holds all properties
    related to a single fit result, that is, a single specific configuration
    that was fit to a 1-D array of data.

    The keys ``params`` and ``stat_tests`` are ``None`` when :py:func:`fit()`
    could not compute them (it initializes both to ``None`` and only fills
    them in on success).
    """
    # Index into the grid of configurations; only stored, not interpreted here.
    grid_idx: int
    # Name of the distribution transform (the enum member's ``name``).
    dist_transform: str
    transform_value: Union[float, None]
    # Fitted parameters by name, or None if the fit failed.
    params: Union[dict[str, Union[float, int]], None]
    # Also, from row.to_dict():
    context: str
    qtype: str
    rv: str
    type: str

    # Results of the statistical tests, or None if the fit or the tests failed.
    stat_tests: Union[StatisticalTestJson, None]
def fit(
    ds: Dataset,
    fitter_type: type[Fitter],
    grid_idx: int, row,
    dist_transform: DistTransform,
    the_data: NDArray[Shape["*"], Float],
    the_data_unique: NDArray[Shape["*"], Float],
    transform_value: Union[float, None],
    write_temporary_results: bool=False
) -> FitResult:
    """
    This is the main stand-alone function that computes a single parametric fit to
    a single 1-D array of data. This function is used in Parallel contexts and,
    therefore, lives on module top level so it can be serialized.

    Note that, unless warning options were given on the command line
    (``sys.warnoptions``), this function silences all warnings in the current process.

    ds: ``Dataset``
        The dataset. It is only used to determine whether the quantity type of
        ``row`` is discrete.

    fitter_type: ``type[Fitter]``
        The class for the fitter to use, either :py:class:`Fitter` or :py:class:`FitterPymoo`.
        It is instantiated with ``dist=<RV class>`` and its ``fit()`` must return a
        dictionary of parameters.

    grid_idx: ``int``
        This is only used so it can be stored in the :py:class:`FitResult`. This method
        itself does not have access to the grid.

    row:
        The grid row describing what to fit. It must expose the attributes ``qtype``,
        ``type`` (``'continuous'`` selects a continuous RV, anything else a discrete
        one) and ``rv`` (class name of the random variable), as well as a ``to_dict()``
        method whose entries are copied into the result.

    dist_transform: ``DistTransform``
        The transform for which to generate parametric fits for. Later, we will save a single
        file per transform, containing all related fits. Its ``name`` is stored in the result.

    the_data: ``NDArray[Shape["*"], Float]``
        The 1-D data used for fitting the RV.

    the_data_unique: ``NDArray[Shape["*"], Float]``
        1-D Array of data. In case of continuous data, it is the same as ``the_data``. In case
        of discrete data, the data in this array contains a slight jitter as to make all data
        points unique. It is used instead of ``the_data`` when a continuous RV is fit to a
        discrete quantity type; for the statistical tests, the jitter is removed again by
        rounding.

    transform_value: ``Union[float, None]``
        The value associated with the transform. It is only stored in the result.

    write_temporary_results: ``bool``
        If ``True``, prints the elapsed time and pickles the result to a file in
        ``./results/temp/``. Errors while writing that file are not caught.

    :rtype: ``FitResult``

    :return:
        The :py:class:`FitResult`. If the RV could not be fitted, then the parameters in
        the fitting result will have a value of ``None``. This is so this method does not
        throw exceptions. In case of a failure, no statistical tests are computed, either.
        An unknown ``row.rv`` is not covered by this and raises a ``KeyError``.
    """
    start = timer()

    import sys
    if not sys.warnoptions:
        import warnings
        warnings.simplefilter("ignore")

    qtype = row.qtype
    fit_continuous = row.type == 'continuous'
    # We will not cover testing fitting of discrete RVs.
    rv_lookup = Continuous_RVs_dict if fit_continuous else Discrete_RVs_dict
    RV: type[Union[rv_continuous, rv_discrete]] = rv_lookup[row.rv]

    # Jittered data is only needed when a continuous RV is fit to discrete data.
    unique_vals = ds.is_qtype_discrete(qtype=qtype) and fit_continuous
    data = the_data_unique if unique_vals else the_data

    ret_dict: FitResult = dict(row.to_dict())
    ret_dict.update(
        grid_idx=grid_idx,
        # Override with a string value.
        dist_transform=dist_transform.name,
        transform_value=transform_value,
        # These two remain None unless the steps below succeed.
        params=None,
        stat_tests=None)

    try:
        fitter = fitter_type(dist=RV)
        params = fitter.fit(data=data, verbose=False)
        params_tuple = tuple(params.values())
        ret_dict['params'] = params

        dist = RV()
        temp_cdf = lambda x: dist.cdf(x, *params_tuple)
        temp_ppf = lambda x: dist.ppf(x, *params_tuple)

        data_st = np.rint(data) if unique_vals else data  # Remove jitter for test
        st = StatisticalTest(data1=data_st, cdf=temp_cdf, ppf_or_data2=temp_ppf, max_samples=25_000)
        ret_dict['stat_tests'] = dict(st)
    except Exception: # pragma: no cover
        # Deliberate best-effort: do nothing at the moment, and allow returning
        # a dict without params and/or stat_tests.
        pass
    finally: # pragma: no cover
        if write_temporary_results:
            elapsed_label = format(timer() - start, "0>5.0f")
            print(f'DONE! it took {elapsed_label} seconds ({row.type}), [{row.rv}]')
            with open(f'./results/temp/{grid_idx}_{elapsed_label}', 'wb') as f:
                dump(ret_dict, f)

    # Returning after (not inside) the `finally` clause ensures that interrupts such
    # as KeyboardInterrupt or SystemExit are never swallowed.
    return ret_dict