huggingface/pytorch-image-models
1. Usage Tips
1.1. Training Commands
Single GPU:
python train.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 128 --validation-batch-size 128 --color-jitter-prob 0.2 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
Multi-GPU; the 4 in the command below means training on 4 GPUs together:
sh distributed_train.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
Another multi-GPU form, which also changes the port the launcher listens on:
python -m torch.distributed.launch --nproc_per_node=3 --master_port=29501 train_v2.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
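On recent PyTorch releases, torch.distributed.launch is deprecated in favor of torchrun. Assuming train_v2.py reads its rank from the environment (timm's script does via utils.init_distributed_device), an equivalent launch would be:
torchrun --nproc_per_node=3 --master_port=29501 train_v2.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images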
1.2. Exporting a Model to ONNX
python onnx_export.py huggingface\pytorch-image-models\output\train\20240529-132242-vit_base_patch16_224-224\model_best.onnx --mean 0 0 0 --std 1 1 1 --img-size 224 --checkpoint huggingface\pytorch-image-models\output\train\20240529-132242-vit_base_patch16_224-224\model_best.pth.tar
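After export, a quick sanity check with onnxruntime (assuming it is installed; 'model_best.onnx' below stands in for the full path produced above) confirms the graph loads and runs:
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('model_best.onnx', providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
print(inp.name, inp.shape)  # e.g. [1, 3, 224, 224]

# mean 0 0 0 / std 1 1 1 were used at export, so preprocessing is just a scale to [0, 1]
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
logits = sess.run(None, {inp.name: x})[0]
print(logits.shape, logits.argmax(axis=1))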
1.3. Classification Data
The training set is organized the same way as yolov8_cls:
│imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
With the data organized this way, launch training as in section 1.1, e.g.:
sh distributed_train_v2.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
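To verify that the layout is picked up correctly, the same create_dataset factory used inside train.py can be pointed at the root ('imagenet' below stands in for your actual --data-dir):
from timm.data import create_dataset

# an empty dataset name selects the default folder (ImageFolder-style) reader;
# class indices are assigned from the sorted subfolder names
ds = create_dataset('', root='imagenet', split='train', is_training=False)
print(len(ds))      # number of images found
img, label = ds[0]  # PIL image and integer class index
print(label, type(img))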
1.4. Modified Classification train.py
Compared with the upstream timm script, this version pins CUDA_VISIBLE_DEVICES to four GPUs and changes several defaults (model vit_base_patch16_224, a custom --data-dir, a --pretrained-path checkpoint, --num-classes 3000, among others):
#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial

import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler

# modification: pin the visible GPUs for this training run
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

has_compile = hasattr(torch, 'compile')

_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,help='path to dataset (positional is *deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR', default=r'/media/lg/C2032F933B04C4E6/00.Data/009.Uniform/81.version-2024.05.25/00.train_224_224',help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,help='Dataset key for target labels.')# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',help='Name of model to train (default: "vit_base_patch16_224")')
group.add_argument('--pretrained', action='store_true', default=False,help='Start with pretrained version of specified network (if avail)')
group.add_argument('--pretrained-path', default='/home/test/pytorch-image-models/output/train/20240528-142446-vit_base_patch16_224-224/last.pth.tar', type=str,help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',help='Load this checkpoint into model after initialization (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=3000, metavar='N',help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,metavar='N N N',help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=1.0, type=float,metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=128, metavar='N',help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,help='Head initialization bias value')# scripting / codegen
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',help="Enable compilation w/ specified backend (default: inductor).")# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--device-modules', default=None, type=str, nargs='+',help="Python imports for device backend modules.")# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',help='LR decay rate (default: 0.1)')# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,help='Crop-mode in train'),
group.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',help='Random resize scale (default: 0.5 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.5,help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-sum', action='store_true', default=False,help='Sum over classes when using BCE loss.')
group.add_argument('--bce-target-thresh', type=float, default=None,help='Threshold for binarizing softened BCE targets (default: None, disabled).')
group.add_argument('--bce-pos-weight', type=float, default=None,help='Positive weighting for BCE loss.')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',help='Drop block rate (default: None)')# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',help='Enable separate BN layers per augmentation split.')# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',help='Enable warmup for model EMA decay.')# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,help='log training and validation metrics to wandb')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)

    return args, args_text


def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()

    if args.device_modules:
        for module in args.device_modules:
            importlib.import_module(module)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.grad_accum_steps = max(1, args.grad_accum_steps)
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process.'
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    factory_kwargs = {}
    if args.pretrained_path:
        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
        factory_kwargs['pretrained_cfg_overlay'] = dict(
            file=args.pretrained_path,
            num_classes=-1,  # force head adaptation
        )

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **factory_kwargs,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=args.model_ema_warmup,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if args.torchcompile:
            model_ema = torch.compile(model_ema, backend=args.torchcompile)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode

    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.train_num_samples,
    )

    if args.val_split:
        dataset_eval = create_dataset(
            args.dataset,
            root=args.data_dir,
            split=args.val_split,
            is_training=False,
            class_map=args.class_map,
            download=args.dataset_download,
            batch_size=args.batch_size,
            input_img_mode=input_img_mode,
            input_key=args.input_key,
            target_key=args.target_key,
            num_samples=args.val_num_samples,
        )

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        train_crop_mode=args.train_crop_mode,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        color_jitter_prob=args.color_jitter_prob,
        grayscale_prob=args.grayscale_prob,
        gaussian_blur_prob=args.gaussian_blur_prob,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        device=device,
        use_prefetcher=args.prefetcher,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = None
    if args.val_split:
        eval_workers = args.workers
        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
            eval_workers = min(2, args.workers)
        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size or args.batch_size,
            is_training=False,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=eval_workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
            device=device,
            use_prefetcher=args.prefetcher,
        )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
    decreasing_metric = eval_metric == 'loss'
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing_metric,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    results = []
    try:
        for epoch in range(start_epoch, num_epochs):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            if loader_eval is not None:
                eval_metrics = validate(
                    model,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
                )

                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                    ema_eval_metrics = validate(
                        model_ema,
                        loader_eval,
                        validate_loss_fn,
                        args,
                        device=device,
                        amp_autocast=amp_autocast,
                        log_suffix=' (EMA)',
                    )
                    eval_metrics = ema_eval_metrics
            else:
                eval_metrics = None

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if eval_metrics is not None:
                latest_metric = eval_metrics[eval_metric]
            else:
                latest_metric = train_metrics[eval_metric]

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, latest_metric)

            results.append({
                'epoch': epoch,
                'train': train_metrics,
                'validation': eval_metrics,
            })

    except KeyboardInterrupt:
        pass

    results = {'all': results}
    if best_metric is not None:
        results['best'] = results['all'][best_epoch - start_epoch]
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    print(f'--result\n{json.dumps(results, indent=4)}')


def train_one_epoch(
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        num_updates_total=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    has_no_sync = hasattr(model, "no_sync")
    update_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    model.train()

    accum_steps = args.grad_accum_steps
    last_accum_steps = len(loader) % accum_steps
    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
    num_updates = epoch * updates_per_epoch
    last_batch_idx = len(loader) - 1
    last_batch_idx_to_accum = len(loader) - last_accum_steps

    data_start_time = update_start_time = time.time()
    optimizer.zero_grad()
    update_sample_count = 0
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        update_idx = batch_idx // accum_steps
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps

        if not args.prefetcher:
            input, target = input.to(device), target.to(device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # multiply by accum steps to get equivalent for full update
        data_time_m.update(accum_steps * (time.time() - data_start_time))

        def _forward():
            with amp_autocast():
                output = model(input)
                loss = loss_fn(output, target)
            if accum_steps > 1:
                loss /= accum_steps
            return loss

        def _backward(_loss):
            if loss_scaler is not None:
                loss_scaler(
                    _loss,
                    optimizer,
                    clip_grad=args.clip_grad,
                    clip_mode=args.clip_mode,
                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                    create_graph=second_order,
                    need_update=need_update,
                )
            else:
                _loss.backward(create_graph=second_order)
                if need_update:
                    if args.clip_grad is not None:
                        utils.dispatch_clip_grad(
                            model_parameters(model, exclude_head='agc' in args.clip_mode),
                            value=args.clip_grad,
                            mode=args.clip_mode,
                        )
                    optimizer.step()

        if has_no_sync and not need_update:
            with model.no_sync():
                loss = _forward()
                _backward(loss)
        else:
            loss = _forward()
            _backward(loss)

        if not args.distributed:
            losses_m.update(loss.item() * accum_steps, input.size(0))
        update_sample_count += input.size(0)

        if not need_update:
            data_start_time = time.time()
            continue

        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
            model_ema.update(model, step=num_updates)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
        time_now = time.time()
        update_time_m.update(time.time() - update_start_time)
        update_start_time = time_now

        if update_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
                update_sample_count *= args.world_size

            if utils.is_primary(args):
                _logger.info(
                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                    f'LR: {lr:.3e} '
                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})')

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True)

        if saver is not None and args.recovery_interval and ((update_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=update_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        update_sample_count = 0
        data_start_time = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])


def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix=''
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.to(device)
                target = target.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

                loss = loss_fn(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            if device.type == 'cuda':
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})')

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics


if __name__ == '__main__':
    main()
1.5. Supported Networks
Optionally, change which network train.py uses:
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',help='Name of model to train (default: "vit_base_patch16_224")')
The supported networks (selected by name) can be listed as follows:
# This example shows how to list all models supported by timm, and all models with pretrained weights.
import timm

# 1. All models supported by timm
# timm supports a wide variety of pretrained and non-pretrained models for a number of image based tasks.
# The list_models function returns a list of models, ordered alphabetically, that are supported by timm.
supported_models = timm.list_models()
print(supported_models[:5])
print(len(supported_models))

# 2. All timm models that have pretrained weights
# To list all the models that have pretrained weights, timm provides a convenience parameter pretrained
# that can be passed to list_models as below.
supported_pretrained_models = timm.list_models(pretrained=True)
print(supported_pretrained_models[:5])
print(len(supported_pretrained_models))

# 3. Use wildcards to look up specific model families, e.g. timm.list_models('*resne*')
resnet_models = timm.list_models('*resnet*')
print(f"ResNet family models: {resnet_models}")
vit_models = timm.list_models('*vit*')
print(f"ViT family models: {vit_models}")
2. Source Code Walkthrough
2.1 pytorch-image-models/timm/models
2.1.1 Model Architecture Files
Model architecture files are the files that provide network structures, i.e., the .py files under pytorch-image-models/timm/models that contain @register_model (in other words, the .py files that supply network models).
Core functionality: define all supported model architectures (e.g., ResNet, ViT, EfficientNet) and provide a unified model creation interface.
Each file corresponds to one model family (e.g., resnet.py, vision_transformer.py, efficientnet.py), defining the network structure and pretrained weight configurations.
2.1.1.1 The Model Registration Mechanism
The register_model function in timm's _registry.py implements automatic model registration: after import timm, and before timm.create_model is ever called, every model has already been stored in _model_entrypoints.
How it works:
- Decorator: register_model is essentially a decorator. A decorator modifies a function's behavior; here, it adds registration to the model builder function.
- Model registration: when a model builder function (e.g. resnet50) is decorated with @register_model, the decorator stores the model's name (resnet50) and its builder function (resnet50) in the _model_entrypoints dictionary.
- Registration on import: because the model files in timm (e.g. resnet.py) are executed automatically when the timm library is imported, and they all use the @register_model decorator, every model is registered at import time.
- create_model lookup: when you call timm.create_model('resnet50'), create_model looks up the builder function named resnet50 in the _model_entrypoints dictionary and calls it to create the model instance.
Summary: register_model uses the decorator mechanism to implement automatic model registration. When the timm library is imported, every model is registered into the _model_entrypoints dictionary, which lets timm.create_model conveniently create model instances (a minimal sketch follows the list of advantages below).
Advantages:
- Centralized management: all model registration information is stored in one place, making it easy to manage and maintain.
- Easy to extend: adding a new model only requires defining a builder function and decorating it with @register_model; no other code needs to change.
- User friendly: users just call timm.create_model to create model instances, without caring about the registration details.
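To make the mechanism concrete, here is a minimal sketch of a decorator-based registry in the same spirit (illustrative only; timm's real _registry.py also tracks module membership, pretrained configs, and deprecations):
_model_entrypoints = {}  # name -> builder fn, the toy analogue of timm's registry

def register_model(fn):
    # record the builder under its function name, then return it unchanged
    _model_entrypoints[fn.__name__] = fn
    return fn

@register_model  # runs at import time, so merely importing this module registers toy_net
def toy_net(num_classes=10):
    import torch.nn as nn
    return nn.Linear(8, num_classes)

def create_model(name, **kwargs):
    if name not in _model_entrypoints:
        raise RuntimeError(f'Unknown model ({name})')
    return _model_entrypoints[name](**kwargs)

model = create_model('toy_net', num_classes=3)  # looks up the builder and calls it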
2.1.2 Non-Architecture Files
Non-architecture files are the files that do not provide network structures, i.e., the .py files under pytorch-image-models/timm/models that do not contain @register_model (in other words, the .py files that do not supply network models). What are these files, and what do they do?
After inspecting the .py files under timm/models, I found the following files that do not contain @register_model:
- __init__.py: This file imports all the other model definition files and any helper modules/classes used in the model definitions. This makes it easier to import models and related functions from a single location (timm.models).
- _builder.py: This file contains helper functions for building models with configurations. It includes functions for loading pretrained weights, adapting models for feature extraction, and handling various model configurations.
- _efficientnet_blocks.py: This file defines the building blocks for EfficientNet-like models, such as Squeeze-and-Excitation blocks, different convolutional blocks, and attention modules.
- _efficientnet_builder.py: This file provides the logic for constructing EfficientNet models based on architecture definitions. It handles scaling of model depths and widths, block configurations, and weight initialization.
- _factory.py: This file contains functions for creating models based on their names. It includes logic for parsing model names, handling pretrained weight loading, and setting model attributes.
- _features.py: This file defines helper classes and functions for extracting features from models. It includes tools for specifying feature extraction points, handling feature metadata, and creating feature extraction wrappers.
- _features_fx.py: This file contains PyTorch FX based feature extraction helpers. It provides tools for tracing model graphs, identifying feature extraction nodes, and creating feature extraction wrappers using FX.
- _helpers.py: This file contains helper functions for model creation, weight loading, and state dictionary manipulation. It includes functions for cleaning state dictionaries, loading checkpoints, and remapping state dictionaries.
- _hub.py: This file contains helper functions for interacting with the Hugging Face Hub. It includes functions for downloading and saving models and configurations, as well as pushing models to the Hub.
- _manipulate.py: This file contains helper functions for manipulating model parameters and modules. It includes functions for applying functions to named modules, grouping modules and parameters, and flattening nested modules.
- _pretrained.py: This file defines data classes and functions for handling pretrained model configurations. It includes classes for storing pretrained weight URLs, input configurations, and other metadata.
- _prune.py: This file contains helper functions for model pruning. It includes functions for extracting and modifying specific layers in a model, as well as adapting models based on pruning configurations.
These files play crucial roles in defining, building, and manipulating models within the timm library, but they don't directly provide network architectures like the files that use @register_model.
2.1.2.1 Further Notes on Several of These .py Files
- Compare the different roles of _efficientnet_blocks.py and _efficientnet_builder.py.
- Compare the different roles of _features.py and _features_fx.py.
- Explain the specific role and usage of _manipulate.py.
- Explain the specific role and usage of _prune.py.
1. _efficientnet_blocks.py vs. _efficientnet_builder.py
- _efficientnet_blocks.py: This module defines the individual building blocks that make up EfficientNet and related models. These blocks are often variations of inverted residual blocks, depthwise separable convolutions, or attention mechanisms. It focuses on the micro-architecture of the models.
- _efficientnet_builder.py: This module handles the construction of complete EfficientNet models by assembling the blocks defined in _efficientnet_blocks.py. It takes care of:
  - Decoding architecture definitions (strings that specify the model structure).
  - Scaling model depth and width based on scaling coefficients.
  - Handling strides and dilations to achieve desired output strides.
  - Selecting feature extraction points within the model.
  - Weight initialization.
In essence, _efficientnet_blocks.py provides the ingredients, while _efficientnet_builder.py provides the recipe and cooking instructions for creating EfficientNet models; a small probe of the builder is sketched below.
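To see the split concretely, the builder side can be probed directly. Hedge: _efficientnet_builder is an internal module, so decode_arch_def and its string format may change between timm versions.
from timm.models._efficientnet_builder import decode_arch_def

# One stage of depthwise-separable blocks, then a stage of inverted-residual blocks.
# The strings encode: block type, repeats (r), kernel (k), stride (s), expansion (e), channels (c).
arch_def = [['ds_r1_k3_s1_e1_c16'], ['ir_r2_k3_s2_e6_c24']]
block_args = decode_arch_def(arch_def, depth_multiplier=1.0)
print(block_args[1][0])  # kwargs for a block class defined in _efficientnet_blocks.py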
2. _features.py vs. _features_fx.py
- _features.py: This module provides tools for extracting intermediate features from models using traditional PyTorch methods like:
  - Hooking into forward passes of specific modules.
  - Rewriting the model to return intermediate activations.
  - Accessing feature information (channel counts, strides) using metadata.
- _features_fx.py: This module leverages the PyTorch FX framework for feature extraction. FX allows for symbolic tracing of model execution, making it easier to:
  - Identify feature extraction points within the model graph.
  - Create feature extraction wrappers without modifying the original model.
  - Optimize feature extraction for specific use cases.
FX generally provides a more flexible and efficient way to extract features, especially for complex models.
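As a concrete example, the machinery in _features.py is what backs the features_only mode of create_model (real timm API; the model name and indices below are arbitrary):
import timm
import torch

# returns a wrapper that outputs intermediate feature maps instead of logits
model = timm.create_model('resnet18', features_only=True, out_indices=(1, 2, 3))
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])        # one tensor per requested stage
print(model.feature_info.channels())   # channel counts from the feature metadata
print(model.feature_info.reduction())  # stride (reduction factor) of each stage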
3. _manipulate.py
This module offers a collection of helper functions for manipulating model parameters and modules. Common use cases include:
- Parameter Grouping: Grouping parameters based on regular expressions or custom matchers. This is useful for applying different learning rates or optimization strategies to different parts of the model.
- Module Grouping: Grouping modules for similar purposes as parameter grouping.
- Module Flattening: Flattening nested modules (e.g., an nn.Sequential inside another nn.Sequential) to simplify model structure.
- Applying Functions: Applying a given function to all named modules or parameters in the model.
- Gradient Checkpointing: A memory-efficient training technique that trades compute for memory by recomputing activations during the backward pass.
Usage Examples:
import timm
from timm.models import flatten_modules, group_parameters, named_apply  # re-exported from _manipulate.py

model = timm.create_model('resnet50')

# Group parameters based on module names
grouped_params = group_parameters(model, {'stem': '^conv1', 'blocks': '^layer'})

# Apply weight initialization to specific modules (init_weights: any callable taking (module, name))
named_apply(init_weights, model, depth_first=False, include_root=True)

# Flatten nested sequential modules
flattened_modules = list(flatten_modules(model.named_modules()))
4. _prune.py
This module provides tools for model pruning, which involves removing less important connections or neurons to reduce model size and complexity. Key functions include:
- Layer Extraction: Extracting specific layers or modules from the model by name or path.
- Layer Modification: Setting or replacing layers within the model.
- Model Adaptation: Adapting a model’s structure based on pruning configurations, such as reducing channel counts or removing layers.
Usage Examples:
from timm.models._prune import adapt_model_from_file, set_layer

# Prune a specific convolutional layer and splice it back in
# (prune_conv_layer is illustrative pseudocode, not an actual timm function)
pruned_conv = prune_conv_layer(original_conv, sparsity=0.5)
set_layer(model, 'blocks.2.1.conv', pruned_conv)

# Adapt a model from a pruning configuration file
adapted_model = adapt_model_from_file(model, 'resnet50_pruned')
These helper modules streamline common model manipulation tasks, making it easier to customize and optimize models within the timm library.
2.2 pytorch-image-models/timm/layers
2.2.1 timm Custom Layers
The timm/layers directory contains a variety of modules and functions that serve as building blocks for the models in the timm library. Here's a breakdown of the key role of each file:
Core Layers & Functions:
- activations.py: Provides a collection of activation functions (ReLU, Swish, Mish, etc.) with a consistent interface for easy swapping and potential JIT scripting/export.
- activations_me.py: Offers memory-efficient versions of some activations using custom autograd, but these are not compatible with JIT or ONNX export.
- adaptive_avgmax_pool.py: Implements adaptive average and max pooling layers, including combinations and concatenation.
- attention2d.py: Defines 2D attention mechanisms, including multi-query attention and spatial attention with downsampling.
- attention_pool.py: Implements attention pooling with a latent query, useful for global feature aggregation.
- attention_pool2d.py: Provides 2D attention pooling mechanisms, including those with learned and rotary position embeddings.
- blur_pool.py: Implements BlurPool, an anti-aliasing technique that combines blurring and downsampling.
- bottleneck_attn.py: Defines the Bottleneck Attention module, a type of self-attention used in Bottleneck Transformers.
- cbam.py: Implements the Convolutional Block Attention Module (CBAM), a combination of channel and spatial attention.
- classifier.py: Provides classifier heads with pooling, dropout, and fully connected layers.
- cond_conv2d.py: Implements Conditionally Parameterized Convolutions (CondConv), which dynamically adjust filters based on input.
- config.py: Manages global configuration flags for layers, such as JIT scripting, ONNX export, and fused attention settings.
- conv2d_same.py: Offers “SAME” padding convolution layers, similar to TensorFlow’s padding behavior.
- conv_bn_act.py: Combines convolution, batch normalization, and activation into a single module.
- create_act.py: Factory function for creating activation functions and layers based on names.
- create_attn.py: Factory function for creating attention modules based on names.
- create_conv2d.py: Factory function for creating different types of 2D convolutions (standard, mixed, conditional).
- create_norm.py: Factory function for creating normalization layers (BatchNorm, GroupNorm, LayerNorm, etc.).
- create_norm_act.py: Factory function for creating combined normalization and activation layers.
- drop.py: Implements DropBlock and DropPath (Stochastic Depth) regularization techniques.
- eca.py: Defines the Efficient Channel Attention (ECA) module.
- evo_norm.py: Implements EvoNorm layers, a type of normalization.
- fast_norm.py: Provides optimized implementations of GroupNorm and LayerNorm for mixed precision training.
- filter_response_norm.py: Implements Filter Response Normalization (FRN) layers.
- format.py: Utilities for handling different tensor formats (NCHW, NHWC, etc.).
- gather_excite.py: Defines the Gather-Excite attention module.
- global_context.py: Implements the Global Context (GC) attention block.
- grid.py: Functions for generating N-dimensional grids.
- grn.py: Implements the Global Response Normalization (GRN) layer.
- halo_attn.py: Defines the Halo Attention module.
- helpers.py: Various helper functions for layers (e.g., make_divisible).
- hybrid_embed.py: Provides layers for embedding CNN feature maps into a transformer-compatible format.
- inplace_abn.py: Implements Inplace Activated Batch Normalization (InplaceABN).
- interpolate.py: Interpolation utilities for layers.
- lambda_layer.py: Defines the Lambda Layer, an attention-like mechanism.
- layer_scale.py: Implements LayerScale, a scaling factor applied to layer outputs.
- linear.py: A modified linear layer with support for mixed precision training.
- median_pool.py: Implements a median pooling layer.
- mixed_conv2d.py: Defines MixedConv2d, which uses multiple kernel sizes in a single convolution.
- mlp.py: Implements Multi-Layer Perceptrons (MLPs) with various configurations.
- ml_decoder.py: Provides an ML decoder head.
- non_local_attn.py: Implements Non-Local Attention blocks.
- norm.py: Normalization layers with fast norm options.
- norm_act.py: Combines normalization and activation layers into a single module.
- padding.py: Helper functions for padding operations, including “SAME” padding.
- patch_dropout.py: Implements PatchDropout, a type of dropout.
- patch_embed.py: Provides layers for converting images to patches for transformer input.
- pool2d_same.py: Pooling layers with “SAME” padding.
- pos_embed.py: Utilities for absolute position embeddings.
- pos_embed_rel.py: Modules and functions for relative position embeddings.
- pos_embed_sincos.py: Implements sin-cos, Fourier, and rotary position embeddings.
- selective_kernel.py: Defines Selective Kernel Convolution and Attention modules.
- separable_conv.py: Implements depthwise separable convolutions.
- space_to_depth.py: Provides space-to-depth and depth-to-space operations.
- split_attn.py: Implements Split Attention (Splat) used in ResNeSt models.
- split_batchnorm.py: Defines SplitBatchNorm, which splits the input batch into multiple parts for separate normalization.
- squeeze_excite.py: Implements Squeeze-and-Excitation (SE) attention blocks.
- std_conv.py: Provides convolutions with weight standardization.
- test_time_pool.py: Implements test-time pooling for improved inference accuracy.
- trace_utils.py: Helper functions for symbolic tracing.
- typing.py: Type hints for layers.
- weight_init.py: Weight initialization functions.
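As a quick illustration of the create_* factories listed above, here is a minimal sketch (assuming a recent timm version; shapes and argument values are illustrative only):
import torch
from timm.layers import create_conv2d, create_act_layer

# create_conv2d picks the appropriate conv implementation; padding='same'
# uses the TF-like "SAME" padding from conv2d_same.py
conv = create_conv2d(16, 32, kernel_size=3, stride=2, padding='same')
# create_act_layer builds an activation module by name
act = create_act_layer('swish', inplace=True)

x = torch.randn(1, 16, 64, 64)
print(act(conv(x)).shape)  # torch.Size([1, 32, 32, 32])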
2.2.2 Why does timm define its own layers?
PyTorch already ships with the layer definitions needed to build networks, such as relu, conv, and pooling. Why, then, does pytorch-image-models/timm/layers define these layers again?
While PyTorch offers a solid foundation for building neural networks, timm provides its own set of layer definitions for several reasons:
1. Enhanced Functionality & Flexibility:
- Advanced Layers: timm includes layers not readily available in core PyTorch, such as:
  - Attention Mechanisms: Various attention blocks (CBAM, ECA, Non-Local Attention) for capturing relationships between features.
  - Pooling Variations: Adaptive average-max pooling, BlurPool for anti-aliasing.
  - Normalization Layers: EvoNorm, Filter Response Norm, InplaceABN.
  - Specialized Convolutions: MixedConv2d (multiple kernel sizes), CondConv2d (dynamic kernels), “SAME” padded convolutions.
- Customizable Building Blocks: timm layers often offer more configuration options than their PyTorch counterparts. For example, ConvNormAct combines convolution, normalization, and activation in a single module with flexible choices for each component (see the sketch after this list).
- Modular Design: timm layers are designed to be modular and easily combined, promoting code reusability and experimentation with different architectures.
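A minimal sketch of the ConvNormAct building block mentioned above (argument values are illustrative; exact keyword names may vary slightly between timm versions):
import torch
import torch.nn as nn
from timm.layers import ConvNormAct

# conv + BatchNorm + ReLU combined into one configurable module
block = ConvNormAct(16, 32, kernel_size=3, stride=2, act_layer=nn.ReLU)
y = block(torch.randn(1, 16, 64, 64))
print(y.shape)  # torch.Size([1, 32, 32, 32])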
2. Optimization & Efficiency:
- Memory-Efficient Activations: timm provides memory-efficient versions of some activations (activations_me.py) using custom autograd, which can be beneficial for large models or limited memory.
- Fast Normalization: Optimized implementations of GroupNorm and LayerNorm (fast_norm.py) for improved performance in mixed precision training.
- “SAME” Padding: Efficient implementation of TensorFlow-like “SAME” padding for convolutions and pooling.
3. Consistency & Compatibility:
- Unified Interface: timm layers often provide a more consistent interface, such as ensuring the channel dimension is always the first argument in normalization layers.
- TorchScript & ONNX Support: Many timm layers are designed with TorchScript and ONNX export in mind, making it easier to deploy models.
4. Research & Experimentation:
- Novel Layers: timm incorporates layers from recent research papers, allowing for quick experimentation with new architectures and ideas.
- Extensibility: The modular design of timm makes it easy to add or modify layers to explore new research directions.
In summary, while PyTorch provides the essentials, timm builds upon them by offering a richer set of layers, optimized implementations, and greater flexibility for building and experimenting with state-of-the-art models.
2.3 pytorch-image-models/timm/loss
The timm/loss directory contains various loss functions that can be used for training deep learning models, particularly those focused on image-related tasks. Here’s a breakdown of the files and their purposes:
- asymmetric_loss.py: Implements the Asymmetric Loss function, designed to address class imbalance in multi-label classification problems. It assigns different weights to positive and negative samples to improve learning in the presence of under-represented classes.
- binary_cross_entropy.py: Provides a Binary Cross Entropy (BCE) loss with additional features like label smoothing, target thresholding, and optional one-hot conversion for dense targets. Useful for binary or multi-label classification tasks.
- cross_entropy.py: Implements two variations of cross-entropy loss:
  - LabelSmoothingCrossEntropy: Applies label smoothing to the standard cross-entropy loss, preventing overconfidence and improving generalization.
  - SoftTargetCrossEntropy: Calculates cross-entropy with soft targets (probability distributions) instead of hard labels. Useful for distillation or other knowledge transfer tasks.
- jsd.py: Implements the JSD (Jensen-Shannon Divergence) Cross Entropy loss. This loss combines cross-entropy with a Jensen-Shannon Divergence term, which encourages the model to produce consistent predictions across augmented versions of the same input. Useful for improving robustness and uncertainty estimation.
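A minimal usage sketch of these losses (smoothing value and tensor shapes are illustrative):
import torch
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = torch.randn(8, 1000)            # dummy model outputs
target = torch.randint(0, 1000, (8,))    # hard integer labels
loss = criterion(logits, target)

# SoftTargetCrossEntropy expects a probability distribution per sample,
# e.g. the soft labels produced by Mixup/Cutmix
soft_criterion = SoftTargetCrossEntropy()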
In summary, the timm/loss module provides a collection of loss functions that go beyond the standard PyTorch offerings. These losses address issues like class imbalance, overconfidence, and robustness, making them valuable tools for training image models.
2.4 pytorch-image-models/timm/optim
2.4.1 Optimizer factory: optim_factory.py
optim_factory.py: Contains functions for creating and registering optimizers, handling parameter groups, and applying weight decay and layer decay.
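For example, a minimal sketch of the factory in use (optimizer name and hyperparameters are illustrative):
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet50', pretrained=False)
# weight decay and parameter group handling happens inside the factory
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)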
2.4.2 Optimizers
Every file here other than optim_factory.py is an optimizer implemented by timm.
The timm/optim directory houses a collection of optimizers that extend beyond the standard optimizers available in PyTorch’s torch.optim module. These optimizers incorporate various enhancements and modifications to improve training performance, convergence speed, and memory efficiency.
2.5 pytorch-image-models/timm/scheduler
2.5.1 scheduler.py
Defines a base class named Scheduler for implementing optimizer parameter schedulers (e.g., learning rate schedulers). It differs from PyTorch’s built-in schedulers in that it focuses on adjusting the optimizer’s parameters at the end of each training cycle, i.e., at each epoch or after each optimizer update. It also supports injecting noise into the schedule, adding some randomness during training to help avoid overfitting. The core idea is to call step or step_update at each epoch or optimizer update to adjust parameters such as the learning rate.
2.5.2 LR schedule
All .py files other than scheduler.py and scheduler_factory.py implement learning rate schedules, inheriting from the Scheduler base class in scheduler.py. scheduler.py and scheduler_factory.py are linked through the concrete learning rate scheduling algorithms, such as CosineLRScheduler.
2.5.3 scheduler_factory.py
Defines a factory function for learning rate schedulers (create_scheduler) along with some helper functions and argument-handling logic, designed to create the appropriate scheduler from a config file or command-line arguments. The code supports multiple scheduler types (such as cosine, step, multistep, plateau, poly) and different scheduling strategies, and it offers advanced features such as learning rate warmup, noise perturbation, and cyclic learning rates, adding flexibility and effectiveness to the training process. By tuning the scheduler’s behavior, model training can be controlled at a fine granularity.
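As an example, a minimal sketch using the cosine schedule directly (epoch counts and learning rates are illustrative):
import torch
from timm.scheduler import CosineLRScheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=90,        # length of the cosine cycle, in epochs
    lr_min=1e-5,
    warmup_t=5,          # 5 warmup epochs
    warmup_lr_init=1e-4,
)
for epoch in range(90):
    # ... train one epoch ...
    scheduler.step(epoch + 1)  # timm schedulers are stepped explicitly with the epoch index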
2.6 pytorch-image-models/timm/data
The timm/data directory contains modules and functions that handle various aspects of data loading, preprocessing, augmentation, and overall data pipeline management for training and evaluating image models. Here’s a categorized summary:
Data Augmentation:
- auto_augment.py: Implements various automatic augmentation strategies, including:
  - AutoAugment: Learns augmentation policies from data.
  - RandAugment: Applies a series of random augmentations with varying magnitudes.
  - AugMix: Combines multiple augmentations with different weights.
- mixup.py: Implements Mixup and Cutmix augmentation techniques, where images and labels are mixed to improve model robustness and generalization.
- random_erasing.py: Implements Random Erasing, an augmentation that randomly erases rectangular regions of an image.
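A minimal Mixup sketch (alpha values are illustrative; note that Mixup expects an even batch size in its default 'batch' mode):
import torch
from timm.data import Mixup

mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
x = torch.randn(8, 3, 224, 224)
y = torch.randint(0, 1000, (8,))
x, y = mixup_fn(x, y)  # y becomes an (8, 1000) soft-label tensor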
Dataset and Sampler:
- dataset.py: Defines ImageDataset (for standard image folders and tar files) and IterableImageDataset (for iterable datasets like TFDS and WDS). Also includes AugMixDataset for applying AugMix.
- distributed_sampler.py: Provides samplers for distributed training:
  - OrderedDistributedSampler: Ensures each process gets a distinct subset of data.
  - RepeatAugSampler: Allows different augmentations of the same sample to be seen by different processes.
- dataset_factory.py: This file acts as a factory for creating various types of datasets. It provides the create_dataset function (see the sketch after this list), which can instantiate datasets from different sources, including:
  - Image folders and tar files: Uses timm’s own ImageDataset.
  - Torchvision datasets: Leverages datasets like ImageNet, CIFAR10, MNIST directly from torchvision.
  - Hugging Face Datasets: Integrates with Hugging Face Datasets for both standard (HFDS) and iterable (HFIDS) datasets.
  - TensorFlow Datasets (TFDS): Uses IterableImageDataset to handle TFDS datasets.
  - WebDataset (WDS): Also uses IterableImageDataset for WDS datasets.
  The factory function handles various dataset-specific configurations, such as splits, image modes, and download options.
- dataset_info.py: This file defines the DatasetInfo abstract base class, which provides a common interface for accessing information about datasets, such as:
  - Number of classes.
  - Label names and descriptions.
  - Mappings between indices and label names.
  It also includes a CustomDatasetInfo class for easily creating dataset information for custom datasets.
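For example, a minimal create_dataset sketch (the path follows the train/val folder layout shown earlier in this document; an empty name selects the plain image-folder reader):
from timm.data import create_dataset

train_dataset = create_dataset(
    name='',                # '' -> read from an image folder
    root='imagenet/train',
    split='train',
    is_training=True,
)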
Transforms:
- transforms.py: Defines a set of image transformations, including:
  - RandomResizedCropAndInterpolation: Randomly crops and resizes with various interpolation modes.
  - CenterCropOrPad: Center crops or pads an image to a target size.
  - RandomCropOrPad: Randomly crops or pads.
  - ResizeKeepRatio: Resizes while maintaining aspect ratio.
  - TrimBorder: Trims a border from an image.
  - And more…
- transforms_factory.py: Provides factory functions (transforms_noaug_train, transforms_imagenet_train, transforms_imagenet_eval) to create sets of transforms optimized for different training and evaluation scenarios.
Data Configuration and Utilities:
- config.py: Provides functions (resolve_data_config, resolve_model_data_config) to determine image size, mean, standard deviation, and other data processing parameters based on model and dataset configurations.
- constants.py: Defines constants like default crop percentage, ImageNet mean and standard deviation, etc.
- imagenet_info.py: Provides a class (ImageNetInfo) to access information about ImageNet dataset subsets, such as class labels, descriptions, and mappings.
- real_labels.py: Implements an evaluator (RealLabelsImagenet) to assess model performance using ImageNet’s “real” labels.
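A small sketch showing how the config helpers tie a model to its preprocessing (assumes a recent timm version where resolve_model_data_config and create_transform are available):
import timm
from timm.data import resolve_model_data_config, create_transform

model = timm.create_model('resnet50', pretrained=False)
data_config = resolve_model_data_config(model)   # input_size, mean, std, interpolation, crop_pct...
transform = create_transform(**data_config, is_training=False)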
Data Loading:
- loader.py: Contains functions for creating data loaders, including create_loader, which handles various data loading aspects like:
  - Batch size, shuffling, and num_workers.
  - Distributed training with samplers.
  - Collation functions (including fast_collate for optimized collation).
  - Prefetching data to the GPU using PrefetchLoader.
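And a minimal create_loader sketch (batch size and worker count are illustrative; the prefetcher requires CUDA):
from timm.data import create_loader

loader = create_loader(
    train_dataset,            # e.g. the dataset created above
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=True,
    use_prefetcher=True,      # wraps the loader in PrefetchLoader
    num_workers=4,
)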
This comprehensive set of modules and functions within timm/data streamlines data loading, preprocessing, and augmentation, making it easier to train and evaluate state-of-the-art image models.
TensorFlow related:
- tf_preprocessing.py: This file enables the use of TensorFlow’s image preprocessing pipeline within PyTorch transforms. It defines the TfPreprocessTransform class, which leverages TensorFlow’s preprocessing functions (like preprocess_for_train and preprocess_for_eval) to perform operations such as:
  - Distorted bounding box cropping.
  - Random cropping and flipping.
  - Center cropping.
  This allows for consistency when evaluating models that were originally trained using TensorFlow’s preprocessing.
2.6.1 pytorch-image-models/timm/data/readers
Dataset Readers:
- reader_image_folder.py: This file defines the ReaderImageFolder class, which enables reading images from folders. It scans folders recursively, infers labels from the folder structure, and provides a mapping between class names and indices.
- reader_image_in_tar.py: This file defines the ReaderImageInTar class, designed for reading images from tar files. It handles single tar files, folders of tar files, and even nested tar files. It also manages class mappings and caching of tar file information.
- reader_image_tar.py: This file defines ReaderImageTar, which reads images from a single tar file. It’s similar to ReaderImageInTar but with more limited functionality. It’s likely to be deprecated in the future.
- reader_tfds.py: This file provides the ReaderTfds class, which wraps TensorFlow Datasets (TFDS) for use in PyTorch. It handles dataset loading, shuffling, batching, and decoding of image samples from TFDS datasets.
- reader_hfds.py: This file defines the ReaderHfds class, which wraps Hugging Face Datasets (HFDS) for use in PyTorch. It handles loading and decoding image samples from HFDS datasets.
- reader_hfids.py: This file defines the ReaderHfids class, which wraps Hugging Face Iterable Datasets (HFIDS) for use in PyTorch. It handles streaming and decoding image samples from HFIDS datasets.
- reader_wds.py: This file provides the ReaderWds class, which wraps WebDataset (WDS) for use in PyTorch. It handles loading and decoding samples from WDS, including support for sharding and distributed training.
Utilities and Helpers:
- class_map.py: This file contains the load_class_map function, which loads a class map from a text file or pickle file. This map is used to associate class names with indices.
- img_extensions.py: This file manages the supported image file extensions. It provides functions like get_img_extensions, is_img_extension, and set_img_extensions to control which file types are recognized as images.
- shared_count.py: This file defines the SharedCount class, which provides a way to share a counter across multiple processes. This is useful for things like epoch tracking in distributed training.
- reader.py: This file defines the base Reader class, which provides a common interface for all dataset readers.
- reader_factory.py: This file contains the create_reader factory function, which instantiates the appropriate reader class based on the provided configuration parameters.
2.7 pytorch-image-models/timm/utils
The timm/utils directory contains a collection of utility modules and functions that support various aspects of model training, evaluation, and manipulation. Here’s a breakdown:
Model and Optimization:
- agc.py: Implements Adaptive Gradient Clipping (AGC), a technique to clip gradients based on the unit-wise norm of parameters, preventing excessive weight updates.
- clip_grad.py: Provides a function (dispatch_clip_grad) to apply different gradient clipping methods, including norm-based clipping, value-based clipping, and AGC.
- model_ema.py: Implements Exponential Moving Average (EMA) of model weights, a technique to maintain a smoothed version of the model’s parameters for better generalization. Includes multiple versions of EMA with varying performance and compatibility.
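For example, a minimal EMA sketch (the decay value is illustrative):
import timm
from timm.utils import ModelEmaV2

model = timm.create_model('resnet50', pretrained=False)
model_ema = ModelEmaV2(model, decay=0.9999)
# after each optimizer step during training:
model_ema.update(model)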
CUDA and AMP:
- cuda.py: Provides utilities for mixed precision training with Automatic Mixed Precision (AMP). Includes both the older ApexScaler and the newer NativeScaler for handling gradient scaling and unscaling.
Training and Checkpointing:
- checkpoint_saver.py: Implements a CheckpointSaver class to manage saving and loading model checkpoints, including tracking the best performing checkpoints and handling recovery checkpoints.
- decay_batch.py: Provides functions for decaying the batch size during training, which can be useful for improving stability and generalization.
Distributed Training:
- distributed.py: Contains utilities for distributed training, including functions to:
  - Initialize distributed devices and processes.
  - Distribute batch normalization statistics across devices.
  - Reduce tensors across processes.
  - Check for the primary (rank 0) process.
Logging and Metrics:
- log.py: Provides functions for setting up logging, including a custom formatter (FormatterNoInfo) that omits ‘INFO’ level logging to the console.
- metrics.py: Defines an AverageMeter class for tracking average values and an accuracy function for calculating top-k accuracy.
- summary.py: Implements the update_summary function to log training and evaluation metrics to a CSV file and optionally to Weights & Biases (wandb).
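A short sketch of the metrics utilities in use (dummy tensors stand in for real model outputs):
import torch
from timm.utils import AverageMeter, accuracy

top1 = AverageMeter()
output = torch.randn(16, 1000)           # dummy logits
target = torch.randint(0, 1000, (16,))
acc1, acc5 = accuracy(output, target, topk=(1, 5))
top1.update(acc1.item(), n=output.size(0))
print(f'avg top-1: {top1.avg:.2f}')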
Miscellaneous:
- attention_extract.py: Implements the AttentionExtract class, which allows extracting attention maps from models using either PyTorch FX or hooks.
- jit.py: Provides functions to configure the JIT (Just-In-Time) compiler for scripting and tracing models.
- misc.py: Contains miscellaneous helper functions, including natural_key for natural sorting of strings and functions for adding boolean arguments to an argument parser.
- model.py: Provides utilities for working with models, including:
  - unwrap_model: Gets the underlying model from wrappers like DataParallel or ModelEma.
  - get_state_dict: Extracts the state dictionary from a model.
  - freeze and unfreeze: Freeze or unfreeze model parameters.
  - reparameterize_model: Converts a model to a deployable form by fusing layers and/or reparameterizing modules.
- onnx.py: Provides functions for exporting models to ONNX format and verifying the exported models.
- random.py: Provides a random_seed function to set seeds for torch, numpy, and random modules.
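For reproducibility, the seeding helper can be called once at startup (a minimal sketch; the rank argument offsets the seed per process in distributed training):
from timm.utils import random_seed

random_seed(42, rank=0)  # seeds torch, numpy and random together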
2.8 pytorch-image-models/tests
Model Tests:
- test_models.py: This file contains comprehensive tests for various aspects of the models in the timm library. It covers a wide range of model architectures and includes tests for:
  - Model inference: Checks that models produce the expected output for given inputs.
  - Forward pass: Verifies the basic forward pass functionality of models.
  - Backward pass: Checks that gradients are calculated correctly during backpropagation.
  - Default configurations: Validates the default configurations provided for each model.
  - Feature extraction: Tests the feature extraction capabilities of models.
  - TorchScript compatibility: Checks if models can be successfully scripted using TorchScript.
  - FX tracing: Tests the symbolic tracing of models using PyTorch FX.
Layer Tests:
- test_layers.py: This file focuses on testing individual layers and modules used within the timm models. It includes tests for:
  - Activation layers: Verifies the functionality of various activation functions, including their gradients.
  - Attention layers: Tests attention mechanisms like Attention2d and MultiQueryAttentionV2.
Optimizer Tests:
- test_optim.py: This file contains tests for the optimizers implemented in the timm/optim directory. It covers a variety of optimizers and includes tests for:
  - Basic optimization cases: Checks the basic functionality of optimizers.
  - State dictionary: Verifies that the optimizer’s state can be saved and loaded correctly.
  - Parameter groups: Tests the creation of parameter groups with weight decay and layer decay.
Utility Tests:
- test_utils.py: This file tests various utility functions and classes provided in the timm/utils directory. It includes tests for:
  - Model freezing and unfreezing: Checks the freeze and unfreeze functions that control parameter freezing.
  - Activation statistics: Tests the ActivationStatsHook and related functions for extracting activation statistics.
  - Model reparameterization: Verifies the reparameterize_model function that converts models to a deployable form.
  - State dictionary: Tests the get_state_dict function with custom unwrap functions.
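The test suite runs under pytest; for example (commands are illustrative, in the same style as the training commands earlier in this document):
python -m pytest tests/test_models.py
python -m pytest tests/test_models.py -k vit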
2.9 pytorch-image-models/*.py
Training and Evaluation:
- train.py: This is the primary script for training models on ImageNet or similar datasets. It provides a comprehensive training loop with various options for:
  - Data loading and augmentation.
  - Model creation and initialization.
  - Optimizer and learning rate scheduler selection.
  - Loss function configuration.
  - Mixed precision training with AMP.
  - Distributed training.
  - Model checkpointing and logging.
- validate.py: This script is used to evaluate trained or pretrained models on ImageNet or similar datasets. It provides options for:
  - Data loading and preprocessing.
  - Model creation and loading from checkpoints.
  - Inference with or without test-time augmentation.
  - Accuracy and loss calculation.
  - Result logging and analysis.
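A typical validate.py invocation, in the same style as the training commands earlier in this document (path and values are illustrative):
python validate.py /path/to/imagenet --model resnet50 --pretrained --batch-size 128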
Benchmarking and Profiling:
- benchmark.py: This script benchmarks the inference and training performance of models. It measures metrics like:
  - Samples per second.
  - Step time (forward, backward, and optimization).
  - GMACs (multiply-accumulate operations).
  - Memory usage.
  It supports different precision modes (AMP, FP32, etc.), JIT scripting, and profiling with DeepSpeed or fvcore.
ONNX Export and Validation:
- onnx_export.py: This script exports PyTorch models to ONNX format, allowing them to be used in other frameworks and environments. It supports various options for controlling the export process, such as opset version and dynamic size.
- onnx_validate.py: This script validates the accuracy and performance of exported ONNX models using the ONNX Runtime. It compares the ONNX model’s outputs to the original PyTorch model’s outputs to ensure correctness.
Other Utilities:
- avg_checkpoints.py: This script averages the weights of multiple model checkpoints, which can be useful for improving model performance and stability (see the example command after this list).
- bulk_runner.py: This script runs the validate.py or benchmark.py script in separate processes for each model in a specified list. This allows for bulk validation or benchmarking of multiple models.
- clean_checkpoint.py: This script cleans a model checkpoint by removing unnecessary data like optimizer state and GPU tensors, making it suitable for sharing and distribution.
- hubconf.py: This file defines the entry points for loading timm models through PyTorch Hub (torch.hub).
- inference.py: This script performs inference on a dataset using a specified model and outputs the results in various formats (CSV, JSON, etc.).
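For example, checkpoint averaging might be invoked as follows (the flags shown are assumptions based on the script's defaults; check python avg_checkpoints.py --help for the exact options):
python avg_checkpoints.py --input output/train/<run_dir> --output averaged.pth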
These scripts and utilities provide a comprehensive toolkit for training, evaluating, benchmarking, exporting, and managing image models within the timm library.