# Source code for xenonpy.utils.parameter_gen

#  Copyright (c) 2021. yoshida-lab. All rights reserved.
#  Use of this source code is governed by a BSD-style
#  license that can be found in the LICENSE file.

from collections import OrderedDict
from typing import Union, Dict, Callable, Sequence, Any, Optional

import numpy as np
import pandas as pd

__all__ = ['ParameterGenerator']


class ParameterGenerator(object):
    """
    Generator for parameter set generating.

    Every keyword argument given to the constructor is a *parameter candidate*;
    calling the instance yields randomly sampled parameter sets built from them.
    """

    # Candidate value types from which a single element is drawn per parameter set.
    _SEQUENCE_TYPES = (tuple, list, np.ndarray, pd.Series)

    def __init__(self, seed: Optional[int] = None, **kwargs: Union[Any, Sequence, Callable, Dict]):
        """
        Parameters
        ----------
        seed
            Numpy random seed. It is passed to ``np.random.seed``, so it (re)seeds
            NumPy's *global* random state.
        kwargs
            Parameter candidate. Each value is classified as:

            * tuple / list / ``np.ndarray`` / ``pd.Series``: one element is drawn
              uniformly at random for every generated parameter set.
            * callable: called with no arguments for every generated parameter set.
            * dict: a repeated draw described by the keys ``'data'`` (a sequence to
              sample from, or a callable that receives the repeat count and returns
              an iterable), ``'repeat'`` (a count, a sequence from which the count is
              drawn, or the name of another parameter whose generated value's length
              is used as the count) and the optional ``'replace'`` (whether sampling
              from a sequence is done with replacement; defaults to ``True``).
            * anything else: passed through unchanged as a constant.

        Raises
        ------
        RuntimeError
            If no parameter candidate is given.
        KeyError
            If a dict candidate has no ``'repeat'`` key.
        """
        if not kwargs:
            raise RuntimeError('need parameter candidate')
        np.random.seed(seed)

        self.tuples = OrderedDict()
        self.funcs = OrderedDict()
        self.dicts = OrderedDict()
        self.others = {}

        for k, v in kwargs.items():
            if isinstance(v, self._SEQUENCE_TYPES):
                self.tuples[k] = v
            elif callable(v):
                self.funcs[k] = v
            elif isinstance(v, dict):
                repeat = v['repeat']
                self.dicts[k] = v
                # Historical ordering heuristic: move the referenced parameter to the
                # end of its group. It is kept because it fixes the order in which
                # values are drawn, and thus the results for a given seed. Correct
                # dependency order is enforced separately by ``_dict_order``.
                if isinstance(repeat, str):
                    for group in (self.tuples, self.dicts, self.funcs):
                        if repeat in group:
                            group.move_to_end(repeat, True)
            else:
                self.others[k] = v

    def _dict_order(self):
        """
        Return the keys of ``self.dicts`` in processing order.

        A dict candidate whose ``'repeat'`` names another dict candidate always
        comes after that candidate. The walk starts from ``reversed(self.dicts)``
        and only pulls a key forward when a dependency requires it. Orders that
        already satisfied every dependency are therefore returned unchanged.

        Raises
        ------
        RuntimeError
            If ``'repeat'`` references among dict candidates form a cycle.
        """
        order = []
        done = set()
        for start in reversed(self.dicts):
            # Each dict candidate depends on at most one other one, so following
            # the chain of ``'repeat'`` names is enough to find every prerequisite.
            chain = []
            node = start
            while node not in done:
                if node in chain:
                    path = ' -> '.join(chain + [node])
                    raise RuntimeError('circular `repeat` reference: %s' % path)
                chain.append(node)
                dep = self.dicts[node]['repeat']
                if not (isinstance(dep, str) and dep in self.dicts):
                    break
                node = dep
            # Emit the deepest prerequisite first.
            for node in reversed(chain):
                done.add(node)
                order.append(node)
        return order

    def __call__(self, num: int, *, factory: Optional[Callable[..., Any]] = None):
        """
        Generate ``num`` parameter sets.

        Parameters
        ----------
        num
            Number of parameter sets to generate.
        factory
            Optional callable. When given, it is called with each parameter set
            unpacked as keyword arguments.

        Yields
        ------
        dict or tuple
            Each yielded value is one of the following:

            * a parameter set, i.e. a dict mapping parameter name to generated value;
            * ``(parameter_set, factory(**parameter_set))`` when ``factory`` is given.

        Raises
        ------
        RuntimeError
            If ``'repeat'`` references among dict candidates are circular.
        KeyError
            If a ``'repeat'`` name does not refer to a generated parameter.
        """
        dict_order = self._dict_order()
        for _ in range(num):
            tmp = {}
            for k, v in self.tuples.items():
                tmp[k] = self._gen(v)
            for k, v in self.funcs.items():
                tmp[k] = v()
            for k in dict_order:
                v = self.dicts[k]
                data = v['data']
                repeat = v['repeat']
                replace = v.get('replace', True)
                if isinstance(repeat, self._SEQUENCE_TYPES):
                    # The repeat count itself is drawn from the given candidates.
                    repeat = self._gen(repeat)
                elif isinstance(repeat, str):
                    # Reuse the length of another parameter's generated value.
                    repeat = len(tmp[repeat])
                if isinstance(data, self._SEQUENCE_TYPES):
                    tmp[k] = self._gen(data, repeat, replace)
                elif callable(data):
                    tmp[k] = tuple(data(repeat))
            # Merge in the constant parameters; their keys come first.
            tmp = dict(self.others, **tmp)
            if factory is not None:
                yield tmp, factory(**tmp)
            else:
                yield tmp

    @staticmethod
    def _gen(item: Sequence, repeat: Optional[int] = None, replace: bool = True):
        """
        Randomly draw from ``item`` using NumPy's global random state.

        Parameters
        ----------
        item
            Candidate values. ``pd.Series`` are indexed by position, not by label.
        repeat
            If ``None``, draw and return a single element. Otherwise draw ``repeat``
            elements and return them as a tuple.
        replace
            Whether the ``repeat`` draws are made with replacement.
        """
        # ``Series[i]`` is label-based and fails (or picks the wrong element) on a
        # non-default index; ``.iloc`` gives the positional access intended here.
        getter = item.iloc if isinstance(item, pd.Series) else item
        if repeat is not None:
            idx = np.random.choice(len(item), repeat, replace=replace)
            return tuple(getter[i] for i in idx)
        return getter[np.random.choice(len(item))]