"""
This is an extra module that holds functions at module (top) level, so that they
can be serialized and dispatched to worker processes with multiprocessing
effortlessly. Here, the main :py:func:`fit()` function is defined.
"""
import numpy as np
from pickle import dump
from typing import TypedDict, Union
from nptyping import Float, NDArray, Shape
from metrics_as_scores.distribution.distribution import DistTransform, Dataset
from metrics_as_scores.distribution.fitting import Continuous_RVs, Discrete_RVs, Fitter, StatisticalTest, StatisticalTestJson
from scipy.stats._distn_infrastructure import rv_continuous, rv_discrete
from timeit import default_timer as timer
Continuous_RVs_dict: dict[str, type[rv_continuous]] = {type(rv).__name__: type(rv) for rv in Continuous_RVs}
"""
Dictionary of continuous random variables that are supported by `scipy.stats`,
keyed by the class name of each random variable.
Note this is a dictionary of types, rather than instances.
"""
Discrete_RVs_dict: dict[str, type[rv_discrete]] = {type(rv).__name__: type(rv) for rv in Discrete_RVs}
"""
Dictionary of discrete random variables that are supported by `scipy.stats`,
keyed by the class name of each random variable.
Note this is a dictionary of types, rather than instances.
"""
def get_data_tuple(
    ds: Dataset,
    qtype: str,
    dist_transform: DistTransform,
    continuous_transform: bool=True,
    sub_sample: int=25_000
) -> list[tuple[str, NDArray[Shape["*"], Float], Union[float, None]]]:
    """
    This method is part of the workflow for computing parametric fits.
    For a specific type of quantity and transform, it creates datasets
    for all available contexts (including the special ``'__ALL__'`` context),
    once using unique values and once using all values.

    :param ds: ``Dataset``
        The dataset to obtain the data from.
    :param qtype: ``str``
        The type of quantity to get datasets for.
    :param dist_transform: ``DistTransform``
        The chosen distribution transform.
    :param continuous_transform: ``bool``
        Whether the transform is real-valued or must be converted to integer.
    :param sub_sample: ``int``
        Passed through unchanged as ``sub_sample`` to :meth:`Dataset.data()`.
        Defaults to ``25_000``.
    :rtype: ``list[tuple[str, NDArray[Shape["*"], Float], Union[float, None]]]``
    :return: A list of tuples of three elements. The first element is a key
        that identifies the context, the quantity type, and whether the data
        was computed using unique values (suffix ``_u``). The second element is
        the data as returned by :meth:`Dataset.transform()`, and the third
        element is the transform value returned alongside it.
    """
    result: list[tuple[str, NDArray[Shape["*"], Float], Union[float, None]]] = []
    for ctx in ds.contexts(include_all_contexts=True):
        # The '__ALL__' pseudo-context is requested from Dataset.data() as
        # context=None (presumably: data across all contexts).
        context = None if ctx == '__ALL__' else ctx
        for unique_vals in (True, False):
            data = ds.data(qtype=qtype, context=context, unique_vals=unique_vals, sub_sample=sub_sample)
            transform_value, data = Dataset.transform(data=data, dist_transform=dist_transform, continuous_value=continuous_transform)
            key = f"{ctx}_{qtype}{('_u' if unique_vals else '')}"
            result.append((key, data, transform_value))
    return result
class FitResult(TypedDict):
    """
    This class is derived from :py:class:`TypedDict` and holds all properties
    related to a single fit result, that is, a single specific configuration
    that was fit to a 1-D array of data. If fitting fails, ``params`` and
    ``stat_tests`` are ``None``.
    """
    # Index of the fitted configuration in the caller's grid of configurations.
    grid_idx: int
    # Name of the DistTransform member (stored as a string, not the enum).
    dist_transform: str
    # The transform value associated with the data; may be None.
    transform_value: Union[float, None]
    # Fitted parameters of the random variable, or None if fitting failed.
    params: Union[dict[str, Union[float, int]], None]
    # Also, from row.to_dict():
    context: str
    qtype: str
    # Class name of the scipy random variable that was fitted.
    rv: str
    # 'continuous' for continuous random variables; otherwise discrete.
    type: str
    # Goodness-of-fit test results, or None if fitting or testing failed.
    stat_tests: Union[StatisticalTestJson, None]
def fit(
    ds: Dataset,
    fitter_type: type[Fitter],
    grid_idx: int, row,
    dist_transform: DistTransform,
    the_data: NDArray[Shape["*"], Float],
    the_data_unique: NDArray[Shape["*"], Float],
    transform_value: Union[float, None],
    write_temporary_results: bool=False
) -> FitResult:
    """
    This is the main stand-alone function that computes a single parametric fit to
    a single 1-D array of data. This function is used in Parallel contexts and,
    therefore, lives on module top level so it can be serialized.

    :param ds: ``Dataset``
        The data, used to determine whether the quantity type of ``row`` is discrete
        (see :meth:`Dataset.is_qtype_discrete()`).
    :param fitter_type: ``type[Fitter]``
        The class for the fitter to use, either :py:class:`Fitter` or :py:class:`FitterPymoo`.
    :param grid_idx: ``int``
        This is only used so it can be stored in the :py:class:`FitResult`. This method
        itself does not have access to the grid.
    :param row:
        The configuration to fit (presumably a pandas row). It must provide the attributes
        ``qtype``, ``type`` (``'continuous'`` selects a continuous RV, anything else a
        discrete one), ``rv`` (the class name of the scipy random variable), and a
        ``to_dict()`` method whose entries are copied into the result.
    :param dist_transform: ``DistTransform``
        The transform for which to generate parametric fits for. Later, we will save a single
        file per transform, containing all related fits.
    :param the_data: ``NDArray[Shape["*"], Float]``
        The 1-D data used for fitting the RV.
    :param the_data_unique: ``NDArray[Shape["*"], Float]``
        1-D Array of data. In case of continuous data, it is the same as ``the_data``. In case
        of discrete data, the data in this array contains a slight jitter as to make all data
        points unique. It is used when a continuous RV is fitted to discrete data; the jitter
        is removed again (by rounding) before the statistical tests are computed.
    :param transform_value: ``Union[float, None]``
        Stored as-is in the :py:class:`FitResult`.
    :param write_temporary_results: ``bool``
        If ``True``, prints the elapsed time and pickles the result to
        ``./results/temp/<grid_idx>_<seconds>`` (the directory is created if missing).
    :raises KeyError: If ``row.rv`` is not a supported random variable.
    :rtype: ``FitResult``
    :return:
        The :py:class:`FitResult`. If the RV could not be fitted, then the parameters in
        the fitting result will have a value of ``None``. This is so this method does not
        throw exceptions. In case of a failure, no statistical tests are computed, either.
    """
    import os
    import sys
    import warnings

    start = timer()
    # Scope the warning filter to this call so the caller's filters are restored
    # afterwards, rather than leaving a process-wide "ignore" filter behind.
    # (catch_warnings is not thread-safe; this function targets process pools.)
    with warnings.catch_warnings():
        if not sys.warnoptions:
            # Only silence warnings if the user did not configure them (e.g. via -W).
            warnings.simplefilter("ignore")

        qtype = row.qtype
        fit_continuous = row.type == 'continuous'
        RV: type[Union[rv_continuous, rv_discrete]]
        if fit_continuous:
            RV = Continuous_RVs_dict[row.rv]
        else: # pragma: no cover
            # We will not cover testing fitting of discrete RVs.
            RV = Discrete_RVs_dict[row.rv]

        # Discrete data fitted with a continuous RV uses the jittered (unique) data.
        unique_vals = ds.is_qtype_discrete(qtype=qtype) and fit_continuous
        data = the_data_unique if unique_vals else the_data

        # Start from the row's entries, then add/override the fit-specific fields.
        # params and stat_tests stay None unless fitting (and testing) succeed.
        ret_dict: FitResult = {
            **row.to_dict(),
            'grid_idx': grid_idx,
            # Override with a string value.
            'dist_transform': dist_transform.name,
            'transform_value': transform_value,
            'params': None,
            'stat_tests': None,
        }
        try:
            fitter = fitter_type(dist=RV)
            params = fitter.fit(data=data, verbose=False)
            # The fitted parameters are passed positionally, in the dict's order.
            params_tuple = tuple(params.values())
            ret_dict['params'] = params
            dist = RV()

            def temp_cdf(x):
                return dist.cdf(x, *params_tuple)

            def temp_ppf(x):
                return dist.ppf(x, *params_tuple)

            data_st = np.rint(data) if unique_vals else data # Remove jitter for test
            st = StatisticalTest(data1=data_st, cdf=temp_cdf, ppf_or_data2=temp_ppf, max_samples=25_000)
            ret_dict['stat_tests'] = dict(st)
        except Exception: # pragma: no cover
            # Deliberately best-effort: a failed fit or test leaves params and/or
            # stat_tests as None instead of aborting the (parallel) batch.
            pass
        finally: # pragma: no cover
            if write_temporary_results:
                elapsed = format(timer() - start, "0>5.0f")
                print(f'DONE! it took {elapsed} seconds ({row.type}), [{row.rv}]')
                os.makedirs('./results/temp', exist_ok=True)
                with open(f'./results/temp/{grid_idx}_{elapsed}', 'wb') as f:
                    dump(ret_dict, f)
    return ret_dict