当前位置: 首页 > news >正文

python-函数前一行加@xxxx的含义参数的约束条件检查装饰器

在这里插入图片描述

在sklearn中看到红框中的函数,于是好奇是什么东西,查到python-函数前一行加@xxxx的含义

在这里插入图片描述

于是找到函数定义:def validate_params(parameter_constraints, *, prefer_skip_nested_validation):
在这里插入图片描述

但是,里面没有定义func参数
于是再看到下面,原来这个函数下面又定义了一个def decorator(func):
这样是可以的嘛?
于是去尝试

def test_func():print(1111)def inner_func(func):func()return @test_func()
def some_func():print("pp")return
some_func()

这也不行啊
在这里插入图片描述

进一步了解到,原来:它是通过 functools 重写了装饰器函数,
你要这样写才行

import functools
def test_func():print(1111)### 装饰器函数def decorator(func):@functools.wraps(func)def wrapper(*args, **kwargs):return func(*args, **kwargs)return wrapperreturn decorator@test_func()
def some_func():print("pp")returnsome_func()

下面是具体介绍

@validate_params 装饰器的运作原理

  1. 装饰器定义

    • validate_params 是一个装饰器函数。它的作用是用于验证被装饰函数的参数类型是否符合预设的约束条件。
  2. 参数约束

    • parameter_constraints 是一个字典,用于定义每个参数的允许类型。例如,可以指定某个参数可以是列表或 NumPy 数组。
  3. 内部装饰器

    • decoratorvalidate_params 内部定义的装饰器函数。它接受被装饰的函数 func 作为参数。
  4. 参数绑定

    • wrapper 函数中,使用 signature(func).bind(*args, **kwargs).arguments 将传入的参数与函数的签名进行绑定,生成一个包含所有参数及其值的字典 params
  5. 参数验证

    • 对字典中的每个参数进行检查。使用 any() 函数来判断该参数的值是否符合定义的约束条件。如果不符合,则抛出一个自定义的异常 InvalidParameterError,并提供错误信息。
  6. 调用原函数

    • 如果所有参数都通过了验证,wrapper 函数就会调用原始的被装饰函数 func,并返回其结果。

@validate_params
装饰器的核心功能是自动检查函数参数的类型。这可以帮助开发者在调用函数之前发现潜在的错误,增强代码的健壮性和可维护性。通过这种方式,确保了函数在执行时获得正确类型的输入,从而减少了运行时错误的风险。
我写了一个示例代码:

import functools
import numpy as np
from inspect import signatureclass InvalidParameterError(ValueError):passdef validate_params(parameter_constraints, *, prefer_skip_nested_validation):"""装饰器用于验证函数和方法的参数类型和值。"""def decorator(func):setattr(func, "_parameter_constraints", parameter_constraints)@functools.wraps(func)def wrapper(*args, **kwargs):params = signature(func).bind(*args, **kwargs).argumentsto_ignore = ["self", "cls"]params = {k: v for k, v in params.items() if k not in to_ignore}validate_parameter_constraints(parameter_constraints, params, caller_name=func.__qualname__)return func(*args, **kwargs)return wrapperreturn decoratordef validate_parameter_constraints(parameter_constraints, params, caller_name):for param, constraints in parameter_constraints.items():if param not in params:continuevalue = params[param]valid = Falsefor constraint in constraints:# 检查是否为类型if isinstance(constraint, type):if isinstance(value, constraint):valid = Truebreak# 检查是否为 Noneelif constraint is None and value is None:valid = Truebreakif not valid:expected_types = ', '.join(c.__name__ if isinstance(c, type) else str(c) for c in constraints)raise InvalidParameterError(f"{caller_name}: '{param}' must be one of types: {expected_types}.")@validate_params({"y_true": [list, np.ndarray],"y_pred": [list, np.ndarray],"sample_weight": [list, np.ndarray, None],},prefer_skip_nested_validation=True,
)
def mean_squared_error(y_true, y_pred, *, sample_weight=None):"""计算均方误差 (MSE)。"""if sample_weight is not None:sample_weight = np.array(sample_weight)y_true = np.array(y_true)y_pred = np.array(y_pred)if sample_weight is not None:return np.average((y_pred - y_true) ** 2, weights=sample_weight)else:return np.mean((y_pred - y_true) ** 2)# 示例用法
y_true = [3, -0.5, 2, 7]  # 真实值
y_pred = [2.5, 0.0, 2, 8]  # 预测值
print(mean_squared_error(y_true, y_pred))  # 输出均方误差

结果
在这里插入图片描述

sklearn中源码


def validate_params(parameter_constraints, *, prefer_skip_nested_validation):"""Decorator to validate types and values of functions and methods.Parameters----------parameter_constraints : dictA dictionary `param_name: list of constraints`. See the docstring of`validate_parameter_constraints` for a description of the accepted constraints.Note that the *args and **kwargs parameters are not validated and must not bepresent in the parameter_constraints dictionary.prefer_skip_nested_validation : boolIf True, the validation of parameters of inner estimators or functionscalled by the decorated function will be skipped.This is useful to avoid validating many times the parameters passed by theuser from the public facing API. It's also useful to avoid validatingparameters that we pass internally to inner functions that are guaranteed tobe valid by the test suite.It should be set to True for most functions, except for those that receivenon-validated objects as parameters or that are just wrappers around classesbecause they only perform a partial validation.Returns-------decorated_function : function or methodThe decorated function."""def decorator(func):# The dict of parameter constraints is set as an attribute of the function# to make it possible to dynamically introspect the constraints for# automatic testing.setattr(func, "_skl_parameter_constraints", parameter_constraints)@functools.wraps(func)def wrapper(*args, **kwargs):global_skip_validation = get_config()["skip_parameter_validation"]if global_skip_validation:return func(*args, **kwargs)func_sig = signature(func)# Map *args/**kwargs to the function signatureparams = func_sig.bind(*args, **kwargs)params.apply_defaults()# ignore self/cls and positional/keyword markersto_ignore = [p.namefor p in func_sig.parameters.values()if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD)]to_ignore += ["self", "cls"]params = {k: v for k, v in params.arguments.items() if k not in to_ignore}validate_parameter_constraints(parameter_constraints, params, caller_name=func.__qualname__)try:with config_context(skip_parameter_validation=(prefer_skip_nested_validation or global_skip_validation)):return func(*args, **kwargs)except InvalidParameterError as e:# When the function is just a wrapper around an estimator, we allow# the function to delegate validation to the estimator, but we replace# the name of the estimator by the name of the function in the error# message to avoid confusion.msg = re.sub(r"parameter of \w+ must be",f"parameter of {func.__qualname__} must be",str(e),)raise InvalidParameterError(msg) from ereturn wrapperreturn decorator@validate_params({"y_true": ["array-like"],"y_pred": ["array-like"],"sample_weight": ["array-like", None],"multioutput": [StrOptions({"raw_values", "uniform_average"}), "array-like"],},prefer_skip_nested_validation=True,
)
def mean_absolute_error(y_true, y_pred, *, sample_weight=None, multioutput="uniform_average"
):"""Mean absolute error regression loss.Read more in the :ref:`User Guide <mean_absolute_error>`.Parameters----------y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)Ground truth (correct) target values.y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)Estimated target values.sample_weight : array-like of shape (n_samples,), default=NoneSample weights.multioutput : {'raw_values', 'uniform_average'}  or array-like of shape \(n_outputs,), default='uniform_average'Defines aggregating of multiple output values.Array-like value defines weights used to average errors.'raw_values' :Returns a full set of errors in case of multioutput input.'uniform_average' :Errors of all outputs are averaged with uniform weight.Returns-------loss : float or ndarray of floatsIf multioutput is 'raw_values', then mean absolute error is returnedfor each output separately.If multioutput is 'uniform_average' or an ndarray of weights, then theweighted average of all output errors is returned.MAE output is non-negative floating point. The best value is 0.0.Examples-------->>> from sklearn.metrics import mean_absolute_error>>> y_true = [3, -0.5, 2, 7]>>> y_pred = [2.5, 0.0, 2, 8]>>> mean_absolute_error(y_true, y_pred)0.5>>> y_true = [[0.5, 1], [-1, 1], [7, -6]]>>> y_pred = [[0, 2], [-1, 2], [8, -5]]>>> mean_absolute_error(y_true, y_pred)0.75>>> mean_absolute_error(y_true, y_pred, multioutput='raw_values')array([0.5, 1. ])>>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])0.85..."""y_type, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)check_consistent_length(y_true, y_pred, sample_weight)output_errors = np.average(np.abs(y_pred - y_true), weights=sample_weight, axis=0)if isinstance(multioutput, str):if multioutput == "raw_values":return output_errorselif multioutput == "uniform_average":# pass None as weights to np.average: uniform meanmultioutput = Nonereturn np.average(output_errors, weights=multioutput)

http://www.mrgr.cn/news/63572.html

相关文章:

  • 《废品机械师抢先版》V0.7.3.b776官方中文学习版
  • 提升租赁效率的租赁小程序全解析
  • 相机和激光雷达的外参标定 - 无标定板版本
  • maven多模块项目编译一直报Failure to find com.xxx.xxx:xxx-xxx-xxx:pom:1.0-SNAPSHOT in问题
  • uniapp 微信小程序内嵌h5实时通信
  • [免费]微信小程序(高校就业)招聘系统(Springboot后端+Vue管理端)【论文+源码+SQL脚本】
  • 数字后端零基础入门系列 | Innovus零基础LAB学习Day8
  • 使用Linux连接阿里云
  • 动态规划-回文串问题——5.最长回文子串
  • 【UML】- 用例图(结合银行案例解释其中的奥义)
  • 残差块(Residual Block)
  • [每日一练]分组后元素最多的组别(all函数的全局比对)
  • 品牌怎么找到用户发的优质内容,进行加热、复制?
  • YOLO——yolo v4(1)
  • 修改Windows远程桌面3389端口
  • 1008:计算(a+b)/c的值
  • 【视频】OpenCV:识别颜色、绘制轮廓
  • 文本文件、二进制文件常见格式
  • 【分立元件】贴片电阻过电压故障机理
  • 【BUG分析】clickhouse表final成功,但存在数据未合并
  • Java: 遍历 Map
  • 优化宝典:数据库性能提升指南
  • 脉冲当量计算方法
  • HJ53 杨辉三角的变形
  • Java 21 新特性来支持并发编程
  • 2024 年 11 月 1 日 deepin 23 内测更新公告