Python: what @xxxx on the line above a function means (a decorator that checks parameter constraints)
While reading the sklearn source I came across a function decorated with @validate_params and got curious what it was, so I looked up what @xxxx on the line before a function definition means.
That led me to the function definition:
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
But its signature does not declare a func parameter (the function being decorated) anywhere.
Reading further down, it turns out another function is defined inside it: def decorator(func):
Can you really do that?
So I gave it a try:
def test_func():
    print(1111)
    def inner_func(func):
        func()
    return

@test_func()
def some_func():
    print("pp")
    return

some_func()
That doesn't work either: test_func() returns None, so @test_func() ends up trying to use None as the decorator and Python complains that 'NoneType' object is not callable.
Digging a bit further, the real requirement is that a decorator taking arguments must be a factory: the outer call has to return an inner decorator(func), which in turn returns a wrapper; sklearn additionally uses functools.wraps so the wrapper keeps the metadata of the original function. You have to write it like this:
import functools
def test_func():
    print(1111)

    ### decorator function
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        return wrapper

    return decorator

@test_func()
def some_func():
    print("pp")
    return

some_func()
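For reference, when this script runs, test_func() executes at decoration time, so it should first print 1111 (when @test_func() is applied to some_func) and then pp (when some_func() is finally called through the wrapper).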
A more detailed breakdown follows.
How the @validate_params decorator works
- Decorator definition: validate_params is a decorator factory whose job is to check that the arguments passed to the decorated function satisfy the declared constraints.
- Parameter constraints: parameter_constraints is a dict that declares the allowed types for each parameter; for example, a parameter may be allowed to be either a list or a NumPy array.
- Inner decorator: decorator is the decorator defined inside validate_params; it receives the decorated function func as its argument.
- Parameter binding: inside wrapper, signature(func).bind(*args, **kwargs).arguments binds the incoming arguments to the function's signature and produces a dict params mapping every parameter name to its value (see the short sketch after this list).
- Parameter validation: every entry of that dict is checked; any() is used to test whether the value satisfies at least one of the declared constraints, and if it does not, a custom InvalidParameterError is raised with an explanatory message.
- Calling the original function: if all parameters pass validation, wrapper calls the original func and returns its result.
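To make the parameter-binding step concrete, here is a minimal sketch (my own toy demo function, not sklearn code) of what signature(func).bind(*args, **kwargs).arguments produces:

from inspect import signature

def demo(a, b, *, c=3):
    pass

bound = signature(demo).bind(1, 2, c=5)
bound.apply_defaults()
# bound.arguments maps each parameter name to the value it received,
# e.g. {'a': 1, 'b': 2, 'c': 5}
print(bound.arguments)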
The core job of @validate_params is to check argument types automatically. It helps developers catch potential errors before the function body ever runs, which makes the code more robust and maintainable: the function is guaranteed to receive correctly typed inputs, reducing the risk of runtime errors.
I wrote a small example to reproduce it:
import functools
import numpy as np
from inspect import signature


class InvalidParameterError(ValueError):
    pass


def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Decorator that validates the types and values of function/method parameters."""

    def decorator(func):
        setattr(func, "_parameter_constraints", parameter_constraints)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            params = signature(func).bind(*args, **kwargs).arguments
            to_ignore = ["self", "cls"]
            params = {k: v for k, v in params.items() if k not in to_ignore}
            validate_parameter_constraints(
                parameter_constraints, params, caller_name=func.__qualname__
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


def validate_parameter_constraints(parameter_constraints, params, caller_name):
    for param, constraints in parameter_constraints.items():
        if param not in params:
            continue
        value = params[param]
        valid = False
        for constraint in constraints:
            # constraint given as a type: check with isinstance
            if isinstance(constraint, type):
                if isinstance(value, constraint):
                    valid = True
                    break
            # constraint given as None: the value is allowed to be None
            elif constraint is None and value is None:
                valid = True
                break
        if not valid:
            expected_types = ', '.join(
                c.__name__ if isinstance(c, type) else str(c) for c in constraints
            )
            raise InvalidParameterError(
                f"{caller_name}: '{param}' must be one of types: {expected_types}."
            )


@validate_params(
    {
        "y_true": [list, np.ndarray],
        "y_pred": [list, np.ndarray],
        "sample_weight": [list, np.ndarray, None],
    },
    prefer_skip_nested_validation=True,
)
def mean_squared_error(y_true, y_pred, *, sample_weight=None):
    """Compute the mean squared error (MSE)."""
    if sample_weight is not None:
        sample_weight = np.array(sample_weight)
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if sample_weight is not None:
        return np.average((y_pred - y_true) ** 2, weights=sample_weight)
    else:
        return np.mean((y_pred - y_true) ** 2)


# Example usage
y_true = [3, -0.5, 2, 7]   # ground-truth values
y_pred = [2.5, 0.0, 2, 8]  # predicted values
print(mean_squared_error(y_true, y_pred))  # prints the mean squared error
Result: running the example prints 0.375.
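To see the validation itself fire, pass an argument whose type is not in the constraint list; with the example code above this should raise the custom InvalidParameterError (the message text comes from the format string in my validate_parameter_constraints):

try:
    mean_squared_error('not a list', y_pred)  # a str is neither list nor np.ndarray
except InvalidParameterError as e:
    print(type(e).__name__, "-", e)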
The original source in sklearn:
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
    """Decorator to validate types and values of functions and methods.

    Parameters
    ----------
    parameter_constraints : dict
        A dictionary `param_name: list of constraints`. See the docstring of
        `validate_parameter_constraints` for a description of the accepted constraints.

        Note that the *args and **kwargs parameters are not validated and must not be
        present in the parameter_constraints dictionary.

    prefer_skip_nested_validation : bool
        If True, the validation of parameters of inner estimators or functions
        called by the decorated function will be skipped.

        This is useful to avoid validating many times the parameters passed by the
        user from the public facing API. It's also useful to avoid validating
        parameters that we pass internally to inner functions that are guaranteed to
        be valid by the test suite.

        It should be set to True for most functions, except for those that receive
        non-validated objects as parameters or that are just wrappers around classes
        because they only perform a partial validation.

    Returns
    -------
    decorated_function : function or method
        The decorated function.
    """

    def decorator(func):
        # The dict of parameter constraints is set as an attribute of the function
        # to make it possible to dynamically introspect the constraints for
        # automatic testing.
        setattr(func, "_skl_parameter_constraints", parameter_constraints)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            global_skip_validation = get_config()["skip_parameter_validation"]
            if global_skip_validation:
                return func(*args, **kwargs)

            func_sig = signature(func)

            # Map *args/**kwargs to the function signature
            params = func_sig.bind(*args, **kwargs)
            params.apply_defaults()

            # ignore self/cls and positional/keyword markers
            to_ignore = [
                p.name
                for p in func_sig.parameters.values()
                if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
            ]
            to_ignore += ["self", "cls"]
            params = {k: v for k, v in params.arguments.items() if k not in to_ignore}

            validate_parameter_constraints(
                parameter_constraints, params, caller_name=func.__qualname__
            )

            try:
                with config_context(
                    skip_parameter_validation=(
                        prefer_skip_nested_validation or global_skip_validation
                    )
                ):
                    return func(*args, **kwargs)
            except InvalidParameterError as e:
                # When the function is just a wrapper around an estimator, we allow
                # the function to delegate validation to the estimator, but we replace
                # the name of the estimator by the name of the function in the error
                # message to avoid confusion.
                msg = re.sub(
                    r"parameter of \w+ must be",
                    f"parameter of {func.__qualname__} must be",
                    str(e),
                )
                raise InvalidParameterError(msg) from e

        return wrapper

    return decorator


@validate_params(
    {
        "y_true": ["array-like"],
        "y_pred": ["array-like"],
        "sample_weight": ["array-like", None],
        "multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],
    },
    prefer_skip_nested_validation=True,
)
def mean_absolute_error(
    y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
):
    """Mean absolute error regression loss.

    Read more in the :ref:`User Guide <mean_absolute_error>`.

    Parameters
    ----------
    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Ground truth (correct) target values.

    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Estimated target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    multioutput : {'raw_values', 'uniform_average'} or array-like of shape \
            (n_outputs,), default='uniform_average'
        Defines aggregating of multiple output values.
        Array-like value defines weights used to average errors.

        'raw_values' :
            Returns a full set of errors in case of multioutput input.

        'uniform_average' :
            Errors of all outputs are averaged with uniform weight.

    Returns
    -------
    loss : float or ndarray of floats
        If multioutput is 'raw_values', then mean absolute error is returned
        for each output separately.
        If multioutput is 'uniform_average' or an ndarray of weights, then the
        weighted average of all output errors is returned.

        MAE output is non-negative floating point. The best value is 0.0.

    Examples
    --------
    >>> from sklearn.metrics import mean_absolute_error
    >>> y_true = [3, -0.5, 2, 7]
    >>> y_pred = [2.5, 0.0, 2, 8]
    >>> mean_absolute_error(y_true, y_pred)
    0.5
    >>> y_true = [[0.5, 1], [-1, 1], [7, -6]]
    >>> y_pred = [[0, 2], [-1, 2], [8, -5]]
    >>> mean_absolute_error(y_true, y_pred)
    0.75
    >>> mean_absolute_error(y_true, y_pred, multioutput='raw_values')
    array([0.5, 1. ])
    >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
    0.85...
    """
    y_type, y_true, y_pred, multioutput = _check_reg_targets(
        y_true, y_pred, multioutput
    )
    check_consistent_length(y_true, y_pred, sample_weight)
    output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)
    if isinstance(multioutput, str):
        if multioutput == "raw_values":
            return output_errors
        elif multioutput == "uniform_average":
            # pass None as weights to np.average: uniform mean
            multioutput = None

    return np.average(output_errors, weights=multioutput)
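As a quick sanity check against the real sklearn decorator (assuming a recent sklearn release where mean_absolute_error carries the constraints shown above), an invalid multioutput string should be rejected before the function body ever runs; InvalidParameterError subclasses ValueError, so it can be caught without importing sklearn's private module:

from sklearn.metrics import mean_absolute_error

try:
    mean_absolute_error([3, -0.5, 2, 7], [2.5, 0.0, 2, 8], multioutput="bad-option")
except ValueError as e:  # InvalidParameterError is a ValueError subclass
    print(type(e).__name__)
    print(e)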