Source code for gefest.core.configs.tuner_params

import inspect
from ctypes import Structure
from typing import Callable, Union

from hyperopt import hp
from pydantic import BaseModel, ConfigDict, field_validator

from gefest.tools.tuners import utils


[docs] class TunerParams(BaseModel): """Dataclass for GolemTuner parameters aggreagtion. Provides easy configuration for tuner with built-in validation and serialization. """ tuner_type: str """Type of GOLEM tuners to use. Available: 'iopt', 'optuna', 'sequential', 'simulataneous'. For tuner details see: https://thegolem.readthedocs.io/en/latest/api/tuning.html https://fedot.readthedocs.io/en/latest/advanced/hyperparameters_tuning.html """ n_steps_tune: int """Number of tuner steps.""" tune_n_best: int = 1 """Top tune_n_best structures from provided population to tune.""" hyperopt_dist: Union[Callable, str] = hp.uniform """Random distribution function.""" verbose: bool = True """GOLEM console info.""" variacne_generator: Union[Callable[[Structure], list[float]], str] = utils.percent_edge_variance """The function for generating the search space includes intervals for each component of each point of each polygon in the provided structure. Output format should be spicific dict configuration. For details see: https://thegolem.readthedocs.io/en/latest/api/tuning.html """ timeout_minutes: int = 60 """GOLEM argument."""
[docs] @field_validator('tuner_type') @classmethod def tuner_type_validate(cls, value): """Checks if specified tuner exists.""" if isinstance(value, str): opt_names = ['iopt', 'optuna', 'sequential', 'simulataneous'] if value in opt_names: return value else: raise ValueError(f'Invalid distribution name: {value}. Allowed names: {opt_names}') else: raise ValueError(f'Invalid argument: {value} of type {type(value)}.')
[docs] @field_validator('hyperopt_dist') @classmethod def hyperopt_fun_validate(cls, value): """Checks if hyperopt distribution function exists.""" if isinstance(value, str): r_ = inspect.getmembers(hp, inspect.isfunction) fun_names = [i[0] for i in r_] if value in fun_names: return getattr(hp, value) else: raise ValueError(f'Invalid distribution name: {value}. Allowed names: {fun_names}') elif isinstance(value, Callable): if value.__module__.split('.')[0] == hp.__name__.split('.')[0]: return value else: raise ValueError(f'Invalid argument: {value} of type {type(value)}.')
[docs] @field_validator('variacne_generator') @classmethod def variacne_generator_fun_validate(cls, value): """Checks if specified variance generation function exists.""" fun_names = ['percent_edge_variance'] if isinstance(value, str): if value in fun_names: return getattr(utils, value) elif isinstance(value, Callable): return value else: raise ValueError(f'Invalid argument: {value} of type {type(value)}.')
model_config = ConfigDict({'arbitrary_types_allowed': True})