当前位置: 首页 > news >正文

从零起步的Kaggle竞赛 - BirdCLEF2025

一个优秀的coder,先从CV工程开始......

首先复制了 LB 0.804- EfficientNet B0 Pytorch Pipeline | Kaggle 这个notebook并尝试提交(Kaggle的notebook中包括参赛者训练好的模型,所以本次提交只能熟悉一下流程而已),ok,0.804,下载了大佬的代码试图在本地修改模型结构并训练。

爬榜日记

20250416:efficientnet训练到loss为0.03左右的时候提交了一次,淦,为什么只有0.510

20250418:

        更换backbone为‘convnext_tiny.in12k_ft_in1k’,并在后面加了一个attention块,loss大约0.023的时候提交,0.596!好耶,马上快及格了。这个backbone是convnext系列中最小的一个模型, 后续会考虑跑大一点的模型试试看

        尝试了maxvit,具体的模型名称是‘maxvit_base_tf_384’,Deepseek说它的模型大小是119M,目前batchsize设置为16,在4090上以20.18G的显存占用训练。看来再大一些的模型就要租显卡了。。。

20250419:

        省流:没必要追求模型参数量,不能 提高效果的同时还会导致timeout。修改了学习率调度策略为warmup+余弦退火。后面要考虑如何利用无标签数据。

        convnext_base在loss为0.02312时得分为0.726。恭喜convnext!
        ok,再大的模型也没必要了,因为发现maxvit_base会超时。(比赛推理时不允许使用GPU,且CPU有时间限制,90min)。那么今天尝试一些轻量级的模型。之前训练时其实没有对val_loss进行记录,现在更新了代码,可以在wandb查看训练情况。

        问DS:音频有底噪,不考虑处理数据的情况下,推荐一些适用于音频分类的模型结构:

tf_efficientnetv2_b0.in1kresnext50_32x4d.a1h_in1k,mobilevit_s.cvnets_in1k

       挨个训练一下。

'tf_efficientnetv2_b0.in1k': 完全可以达到和convnext_base相同的得分...0.729,那我之前花那么多时间占那么大现存跑base是为了什么(错乱)而且推理还慢
'resnext50_32x4d.a1h_in1k'
'mobilevit_s.cvnets_in1k'

以下是大佬的notebook中的代码,可以直接提交。由于其中只含有加载模型推理的代码,我把它命名为test.py。这样后续我就可以从test.py调用模型结构进行训练,无需重复定义模型,而且也只需修改一次模型结构。

导包

import os
import gc
import warnings
import logging
import time
import math
import cv2
from pathlib import Pathimport numpy as np
import pandas as pd
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from tqdm.auto import tqdm# Suppress warnings and limit logging output
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

训练参数


class CFG:"""Configuration class holding all paths and parameters required for the inference pipeline."""test_soundscapes = '/kaggle/input/birdclef-2025/test_soundscapes'submission_csv = '/kaggle/input/birdclef-2025/sample_submission.csv'taxonomy_csv = '/kaggle/input/birdclef-2025/taxonomy.csv'model_path = '/kaggle/input/birdclef-2025-efficientnet-b0' # 从这里上传?# Audio parametersFS = 32000WINDOW_SIZE = 5# Mel spectrogram parametersN_FFT = 1034HOP_LENGTH = 64N_MELS = 136FMIN = 20FMAX = 16000TARGET_SHAPE = (256, 256)model_name = 'efficientnet_b0'in_channels = 1device = 'cpu'# Inference parametersbatch_size = 16use_tta = Falsetta_count = 3threshold = 0.7use_specific_folds = False  # If False, use all found modelsfolds = [0, 1]  # Used only if use_specific_folds is Truedebug = Falsedebug_count = 3

模型定义

可以看到大佬选择用timm库中的模型作为骨干网络(选择的是efficientnet b0)

输出后经过一个池化层

最后经过一个分类头来适配比赛中的分类任务

class BirdCLEFModel(nn.Module):"""Custom neural network model for BirdCLEF-2025 that uses a timm backbone."""def __init__(self, cfg, num_classes):"""Initialize the BirdCLEFModel.:param cfg: Configuration parameters.:param num_classes: Number of output classes."""super().__init__()self.cfg = cfg# Create backbone using timm with specified parameters.self.backbone = timm.create_model(cfg.model_name,pretrained=False,in_chans=cfg.in_channels,drop_rate=0.0,drop_path_rate=0.0)# Adjust final layers based on model typeif 'efficientnet' in cfg.model_name:backbone_out = self.backbone.classifier.in_featuresself.backbone.classifier = nn.Identity()elif 'resnet' in cfg.model_name:backbone_out = self.backbone.fc.in_featuresself.backbone.fc = nn.Identity()else:backbone_out = self.backbone.get_classifier().in_featuresself.backbone.reset_classifier(0, '')self.pooling = nn.AdaptiveAvgPool2d(1)self.feat_dim = backbone_outself.classifier = nn.Linear(backbone_out, num_classes)def forward(self, x):"""Forward pass through the network.:param x: Input tensor.:return: Logits for each class."""features = self.backbone(x)if isinstance(features, dict):features = features['features']# If features are 4D, apply global average pooling.if len(features.shape) == 4:features = self.pooling(features)features = features.view(features.size(0), -1)logits = self.classifier(features)return logits

pipeline定义

那么么有人就要问了:什么是pipeline呢??

class BirdCLEF2025Pipeline:"""Pipeline for the BirdCLEF-2025 inference task.This class organizes the complete inference process:- Loading taxonomy data.- 加载预训练模型文件.- 将音频文件处理成梅尔频谱.- 对每个音频片段进行预测.- 生成提交所需的结果文件.- 对结果文件进行后处理,以 smooth predictions? 这句没看懂"""def __init__(self, cfg):"""根据所给参数初始化inference pipeline.:param cfg: Configuration object with paths and parameters."""self.cfg = cfgself.taxonomy_df = Noneself.species_ids = []self.models = []self._load_taxonomy()def _load_taxonomy(self):"""Load taxonomy data from CSV and extract species identifiers."""print("Loading taxonomy data...")self.taxonomy_df = pd.read_csv(self.cfg.taxonomy_csv)self.species_ids = self.taxonomy_df['primary_label'].tolist()print(f"Number of classes: {len(self.species_ids)}")def audio2melspec(self, audio_data):"""将原始音频文件处理为梅尔频谱:param audio_data: 1D numpy array of audio samples.:return: Normalized mel spectrogram."""if np.isnan(audio_data).any():mean_signal = np.nanmean(audio_data)audio_data = np.nan_to_num(audio_data, nan=mean_signal)mel_spec = librosa.feature.melspectrogram(y=audio_data,sr=self.cfg.FS,n_fft=self.cfg.N_FFT,hop_length=self.cfg.HOP_LENGTH,n_mels=self.cfg.N_MELS,fmin=self.cfg.FMIN,fmax=self.cfg.FMAX,power=2.0)mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)return mel_spec_normdef process_audio_segment(self, audio_data):"""Process an audio segment to obtain a mel spectrogram with the target shape.:param audio_data: 1D numpy array of audio samples.:return: Processed mel spectrogram as a float32 numpy array."""# Pad audio if it is shorter than the required window size.if len(audio_data) < self.cfg.FS * self.cfg.WINDOW_SIZE:audio_data = np.pad(audio_data,(0, self.cfg.FS * self.cfg.WINDOW_SIZE - len(audio_data)),mode='constant')mel_spec = self.audio2melspec(audio_data)# Resize spectrogram to the target shape if necessary.if mel_spec.shape != self.cfg.TARGET_SHAPE:mel_spec = cv2.resize(mel_spec, self.cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)return mel_spec.astype(np.float32)def find_model_files(self):"""Find all .pth model files in the specified model directory.:return: List of model file paths."""model_files = []model_dir = Path(self.cfg.model_path)for path in model_dir.glob('**/*.pth'):model_files.append(str(path))return model_filesdef load_models(self):"""Load all found model files and prepare them for ensemble inference.:return: List of loaded PyTorch models."""self.models = []model_files = self.find_model_files()if not model_files:print(f"Warning: No model files found under {self.cfg.model_path}!")return self.modelsprint(f"Found a total of {len(model_files)} model files.")# If specific folds are required, filter the model files.if self.cfg.use_specific_folds:filtered_files = []for fold in self.cfg.folds:fold_files = [f for f in model_files if f"fold{fold}" in f]filtered_files.extend(fold_files)model_files = filtered_filesprint(f"Using {len(model_files)} model files for the specified folds ({self.cfg.folds}).")# Load each model file.for model_path in model_files:try:print(f"Loading model: {model_path}")checkpoint = torch.load(model_path, map_location=torch.device(self.cfg.device))model = BirdCLEFModel(self.cfg, len(self.species_ids))model.load_state_dict(checkpoint['model_state_dict'])model = model.to(self.cfg.device)model.eval()self.models.append(model)except Exception as e:print(f"Error loading model {model_path}: {e}")return self.modelsdef apply_tta(self, spec, tta_idx):"""Apply test-time augmentation (TTA) to the spectrogram.:param spec: Input mel spectrogram.:param tta_idx: Index indicating which TTA to apply.:return: Augmented spectrogram."""if tta_idx == 0:# No augmentation.return specelif tta_idx == 1:# Time shift (horizontal flip).return np.flip(spec, axis=1)elif tta_idx == 2:# Frequency shift (vertical flip).return np.flip(spec, axis=0)else:return specdef predict_on_spectrogram(self, audio_path):"""Process a single audio file and predict species presence for each 5-second segment.:param audio_path: Path to the audio file.:return: Tuple (row_ids, predictions) for each segment."""predictions = []row_ids = []soundscape_id = Path(audio_path).stemtry:print(f"Processing {soundscape_id}")audio_data, _ = librosa.load(audio_path, sr=self.cfg.FS)total_segments = int(len(audio_data) / (self.cfg.FS * self.cfg.WINDOW_SIZE))for segment_idx in range(total_segments):start_sample = segment_idx * self.cfg.FS * self.cfg.WINDOW_SIZEend_sample = start_sample + self.cfg.FS * self.cfg.WINDOW_SIZEsegment_audio = audio_data[start_sample:end_sample]end_time_sec = (segment_idx + 1) * self.cfg.WINDOW_SIZErow_id = f"{soundscape_id}_{end_time_sec}"row_ids.append(row_id)if self.cfg.use_tta:all_preds = []for tta_idx in range(self.cfg.tta_count):mel_spec = self.process_audio_segment(segment_audio)mel_spec = self.apply_tta(mel_spec, tta_idx)mel_spec_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)mel_spec_tensor = mel_spec_tensor.to(self.cfg.device)if len(self.models) == 1:with torch.no_grad():outputs = self.models[0](mel_spec_tensor)probs = torch.sigmoid(outputs).cpu().numpy().squeeze()all_preds.append(probs)else:segment_preds = []for model in self.models:with torch.no_grad():outputs = model(mel_spec_tensor)probs = torch.sigmoid(outputs).cpu().numpy().squeeze()segment_preds.append(probs)avg_preds = np.mean(segment_preds, axis=0)all_preds.append(avg_preds)final_preds = np.mean(all_preds, axis=0)else:mel_spec = self.process_audio_segment(segment_audio)mel_spec_tensor = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)mel_spec_tensor = mel_spec_tensor.to(self.cfg.device)if len(self.models) == 1:with torch.no_grad():outputs = self.models[0](mel_spec_tensor)final_preds = torch.sigmoid(outputs).cpu().numpy().squeeze()else:segment_preds = []for model in self.models:with torch.no_grad():outputs = model(mel_spec_tensor)probs = torch.sigmoid(outputs).cpu().numpy().squeeze()segment_preds.append(probs)final_preds = np.mean(segment_preds, axis=0)predictions.append(final_preds)except Exception as e:print(f"Error processing {audio_path}: {e}")return row_ids, predictionsdef run_inference(self):"""Run inference on all test soundscape audio files.:return: Tuple (all_row_ids, all_predictions) aggregated from all files."""test_files = list(Path(self.cfg.test_soundscapes).glob('*.ogg'))if self.cfg.debug:print(f"Debug mode enabled, using only {self.cfg.debug_count} files")test_files = test_files[:self.cfg.debug_count]print(f"Found {len(test_files)} test soundscapes")all_row_ids = []all_predictions = []for audio_path in tqdm(test_files):row_ids, predictions = self.predict_on_spectrogram(str(audio_path))all_row_ids.extend(row_ids)all_predictions.extend(predictions)return all_row_ids, all_predictionsdef create_submission(self, row_ids, predictions):"""Create the submission dataframe based on predictions.:param row_ids: List of row identifiers for each segment.:param predictions: List of prediction arrays.:return: A pandas DataFrame formatted for submission."""print("Creating submission dataframe...")submission_dict = {'row_id': row_ids}for i, species in enumerate(self.species_ids):submission_dict[species] = [pred[i] for pred in predictions]submission_df = pd.DataFrame(submission_dict)submission_df.set_index('row_id', inplace=True)sample_sub = pd.read_csv(self.cfg.submission_csv, index_col='row_id')missing_cols = set(sample_sub.columns) - set(submission_df.columns)if missing_cols:print(f"Warning: Missing {len(missing_cols)} species columns in submission")for col in missing_cols:submission_df[col] = 0.0submission_df = submission_df[sample_sub.columns]submission_df = submission_df.reset_index()return submission_dfdef smooth_submission(self, submission_path):"""Post-process the submission CSV by smoothing predictions to enforce temporal consistency.For each soundscape (grouped by the file name part of 'row_id'), each row's predictionsare averaged with those of its neighbors using defined weights.:param submission_path: Path to the submission CSV file."""print("Smoothing submission predictions...")sub = pd.read_csv(submission_path)cols = sub.columns[1:]# Extract group names by splitting row_id on the last underscoregroups = sub['row_id'].str.rsplit('_', n=1).str[0].valuesunique_groups = np.unique(groups)for group in unique_groups:# Get indices for the current groupidx = np.where(groups == group)[0]sub_group = sub.iloc[idx].copy()predictions = sub_group[cols].valuesnew_predictions = predictions.copy()if predictions.shape[0] > 1:# Smooth the predictions using neighboring segmentsnew_predictions[0] = (predictions[0] * 0.8) + (predictions[1] * 0.2)new_predictions[-1] = (predictions[-1] * 0.8) + (predictions[-2] * 0.2)for i in range(1, predictions.shape[0] - 1):new_predictions[i] = (predictions[i - 1] * 0.2) + (predictions[i] * 0.6) + (predictions[i + 1] * 0.2)# Replace the smoothed values in the submission dataframesub.iloc[idx, 1:] = new_predictionssub.to_csv(submission_path, index=False)print(f"Smoothed submission saved to {submission_path}")def run(self):"""Main method to execute the complete inference pipeline.This method:- Loads the pre-trained models.- Processes test audio files and runs predictions.- Creates the submission CSV.- Applies smoothing to the predictions."""start_time = time.time()print("Starting BirdCLEF-2025 inference...")print(f"TTA enabled: {self.cfg.use_tta} (variations: {self.cfg.tta_count if self.cfg.use_tta else 0})")self.load_models()if not self.models:print("No models found! Please check model paths.")returnprint(f"Model usage: {'Single model' if len(self.models) == 1 else f'Ensemble of {len(self.models)} models'}")row_ids, predictions = self.run_inference()submission_df = self.create_submission(row_ids, predictions)submission_path = 'submission.csv'submission_df.to_csv(submission_path, index=False)print(f"Initial submission saved to {submission_path}")# Apply smoothing on the submission predictions.self.smooth_submission(submission_path)end_time = time.time()print(f"Inference completed in {(end_time - start_time) / 60:.2f} minutes")# Run the BirdCLEF2025 Pipeline:
if __name__ == "__main__":cfg = CFG()print(f"Using device: {cfg.device}")pipeline = BirdCLEF2025Pipeline(cfg)pipeline.run()

训练代码

由于想要自己训练一个模型,所以另外写了一个train.py
注意其中的

    train_audio_dir = '/root/autodl-tmp/BirdCLEF2025/train_audio' train_csv = '/root/autodl-tmp/BirdCLEF2025/train.csv'taxonomy_csv = '/root/autodl-tmp/BirdCLEF2025/taxonomy.csv' output_dir = ""

需要修改为你实际存放数据的位置。

以下是完整的train.py。如果报有关多线程的错,把TrainCFG中的num_workers设置成0就好。

(因为这部分我也没太搞懂)

# train.py
import os
import pandas as pd
import numpy as np
import librosa
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm# 继承test.py中的原始组件
from test import CFG, BirdCLEFModelimport warnings  # 必须放在最顶部
warnings.filterwarnings("ignore")  # 忽略所有警告# ---------------------- 扩展训练配置 ----------------------
class TrainCFG(CFG):"""新增训练专用参数"""# 数据路径需要覆盖父类配置train_audio_dir = '/root/autodl-tmp/BirdCLEF2025/train_audio' # "./data/birdclef-2025/train_audio"train_csv = '/root/autodl-tmp/BirdCLEF2025/train.csv'  # "./data/birdclef-2025/train.csv"taxonomy_csv = '/root/autodl-tmp/BirdCLEF2025/taxonomy.csv' # './data/birdclef-2025/taxonomy.csv'output_dir = "./checkpoints"# 训练参数device = "cuda" # if torch.cuda.is_available() else "cpu"num_epochs = 20lr = 1e-4batch_size = 256num_workers = 4num_folds = 5seed = 42# 标签平滑参数label_smoothing = 0.05# 混合精度训练use_amp = True# ---------------------- 核心数据处理器 ----------------------
class BirdDataset(Dataset):def __init__(self, cfg, df, audio_dir, is_train=True):"""保持与test.py中spectrogram生成逻辑一致:param df: 从train.csv加载的DataFrame"""self.cfg = cfgself.df = df.reset_index(drop=True)self.audio_dir = audio_dirself.is_train = is_train# 从taxonomy获取标签映射taxonomy = pd.read_csv(cfg.taxonomy_csv)self.label_mapping = {row['primary_label']: idxfor idx, row in taxonomy.iterrows()}print(f"Total classes: {len(self.label_mapping)}")def __len__(self):return len(self.df)def _load_audio(self, filename):"""严格保持与test.py相同的音频加载逻辑"""audio_path = os.path.join(self.audio_dir, filename)# 异常处理与test.py一致try:audio, _ = librosa.load(audio_path, sr=self.cfg.FS)if np.isnan(audio).any():audio = np.nan_to_num(audio, nan=np.mean(audio))except Exception as e:print(f"Error loading {audio_path}: {e}")audio = np.zeros(self.cfg.FS * 5)return audiodef _process_segment(self, audio):"""严格复制test.py中的频谱生成代码"""# 填充逻辑需要完全相同if len(audio) < self.cfg.FS * self.cfg.WINDOW_SIZE:audio = np.pad(audio,(0, self.cfg.FS * self.cfg.WINDOW_SIZE - len(audio)),mode='constant')# Mel频谱生成参数完全一致mel_spec = librosa.feature.melspectrogram(y=audio,sr=self.cfg.FS,n_fft=self.cfg.N_FFT,hop_length=self.cfg.HOP_LENGTH,n_mels=self.cfg.N_MELS,fmin=self.cfg.FMIN,fmax=self.cfg.FMAX,power=2.0)mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min() + 1e-8)# 调整尺寸方式与test.py完全一致return cv2.resize(mel_spec_norm, self.cfg.TARGET_SHAPE, interpolation=cv2.INTER_LINEAR)def __getitem__(self, idx):row = self.df.iloc[idx]# 1.音频加载与预处理audio = self._load_audio(row['filename'])# 2.保持数据增强与test.py的兼容性# (注意:训练时需要自定义增广,但推理时不应启用)if self.is_train:# 随机时间裁剪(保持核心逻辑但扩展为训练模式)if len(audio) > self.cfg.FS * self.cfg.WINDOW_SIZE:start = np.random.randint(0, len(audio) - self.cfg.FS * self.cfg.WINDOW_SIZE)audio = audio[start: start + self.cfg.FS * self.cfg.WINDOW_SIZE]# 3.严格使用test.py频谱生成方法spec = self._process_segment(audio)  # shape (256,256)# 4.目标生成(保持与模型输出的206类一致)target = torch.zeros(len(self.label_mapping), dtype=torch.float32)primary_idx = self.label_mapping.get(row['primary_label'], -1)if primary_idx != -1:target[primary_idx] = 1.0 - self.cfg.label_smoothingtarget += self.cfg.label_smoothing / len(target)return {'spec': torch.tensor(spec).unsqueeze(0),  # shape [1,256,256]'target': target  # shape [206]}# ---------------------- 训练循环 ----------------------
def train_fn(cfg, model, train_loader, optimizer, criterion):model.train()total_loss = 0.0progress = tqdm(train_loader, desc="Training", leave=False)scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)for batch in progress:specs = batch['spec'].to(cfg.device)  # shape [B,1,256,256]targets = batch['target'].to(cfg.device)  # shape [B,206]optimizer.zero_grad()with torch.cuda.amp.autocast(enabled=cfg.use_amp):outputs = model(specs)  # 完全保留test.py的forward逻辑loss = criterion(outputs, targets)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()total_loss += loss.item()progress.set_postfix(loss=loss.item())return total_loss / len(train_loader)def validate_fn(cfg, model, val_loader, criterion):model.eval()total_loss = 0.0progress = tqdm(val_loader, desc="Validating", leave=False)with torch.no_grad():for batch in progress:specs = batch['spec'].to(cfg.device)targets = batch['target'].to(cfg.device)outputs = model(specs)loss = criterion(outputs, targets)total_loss += loss.item()return total_loss / len(val_loader)# ---------------------- 主流程 ----------------------
def main():cfg = TrainCFG()os.makedirs(cfg.output_dir, exist_ok=True)# 确保不同来源的配置同步cfg.TARGET_SHAPE = (256, 256)  # 与test.py完全一致torch.manual_seed(cfg.seed)# 加载数据train_df = pd.read_csv(cfg.train_csv)taxonomy = pd.read_csv(cfg.taxonomy_csv)assert len(taxonomy) == 206, "Taxonomy类数应与模型输出一致"# Cross-validation训练循环skf = StratifiedKFold(n_splits=cfg.num_folds)for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df['primary_label'])):print(f"\n{'=' * 25} Fold {fold + 1}/{cfg.num_folds} {'=' * 25}")# 数据加载器print('loading dataset...')train_ds = BirdDataset(cfg, train_df.iloc[train_idx], cfg.train_audio_dir)val_ds = BirdDataset(cfg, train_df.iloc[val_idx], cfg.train_audio_dir, is_train=False)train_loader = DataLoader(train_ds,batch_size=cfg.batch_size,shuffle=True,num_workers=0,#cfg.num_workers,pin_memory=True)val_loader = DataLoader(val_ds,batch_size=cfg.batch_size * 2,shuffle=False,num_workers=0,#cfg.num_workers,)# 初始化与test.py完全一致的模型结构print('constructing MODEL...')model = BirdCLEFModel(cfg, num_classes=len(taxonomy)).to(cfg.device)optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)criterion = torch.nn.BCEWithLogitsLoss()  # 使用与sigmoid推理一致的目标函数# 训练循环best_val_loss = float('inf')for epoch in range(1, cfg.num_epochs + 1):print(f"Epoch {epoch}/{cfg.num_epochs}")train_loss = train_fn(cfg, model, train_loader, optimizer, criterion)val_loss = validate_fn(cfg, model, val_loader, criterion)# 保存最佳模型(与test.py加载格式完全兼容)if val_loss < best_val_loss:best_val_loss = val_lossckpt_path = os.path.join(cfg.output_dir, f"best_fold{fold}.pth")torch.save({'model_state_dict': model.state_dict(),'config': vars(cfg)}, ckpt_path)print(f"Fold {fold} New best model saved (val_loss={val_loss:.4f})")print(f"Fold {fold} completed. Best val loss: {best_val_loss:.4f}")if __name__ == "__main__":main()

在代码中学:

num_folds(折数)通常指交叉验证中的子集划分数量,用于评估模型的泛化性能。以下是详细解释:


一、核心作用

  1. 数据利用率优化
    将数据集划分为K个子集(K=num_folds),进行K次训练/验证,每次用 K-1个子集训练1个子集验证,充分利用有限数据。

  2. 评估稳定性增强
    通过多个不同验证集的平均结果,减少因数据划分随机性带来的评估偏差。


二、常用场景

场景应用方式
交叉验证训练num_folds=5, 运行5次训练后平均结果
集成学习每折训练一个子模型,最终预测为多模型投票或平均
超参数调优在每折中搜索最佳参数,选择平均性能最优的配置
小数据集验证数据量少时提高验证可靠性(常用num_folds=5/10

三、工作流程示例(5折交叉验证)

数据集划分:
原始数据 ➜ 划分为5等份(F1~F5)

训练轮次训练集验证集评估模型
第1折F2+F3+F4+F5F1Model_1
第2折F1+F3+F4+F5F2Model_2
第3折F1+F2+F4+F5F3Model_3
第4折F1+F2+F3+F5F4Model_4
第5折F1+F2+F3+F4F5Model_5

最终性能:
取5次验证结果的均值(如准确率、F1分数等)

在训练中学:

发现随着训练时间增加,显存占用明显增多,这是为啥?


http://www.mrgr.cn/news/99146.html

相关文章:

  • Python多任务编程:进程全面详解与实战指南
  • 实现AWS Lambda函数安全地请求企业内部API返回数据
  • 【C++详解】C++入门(一)命名空间、缺省参数、函数重载
  • Linux安装mysql_exporter
  • 第 28 场 蓝桥月赛
  • 线性DP:最长上升子序列(子序列可不连续,子数组必须连续)
  • C++ 模块化编程(Modules)在大规模系统中的实践难点
  • acwing--动态规划【线性dp】4/20、4/21
  • 大数据应用开发——大数据平台集群部署(四)
  • 机器学习专栏(4):从数据饥荒到模型失控,破解AI训练的七大生死劫
  • 分布类相关的可视化图像
  • 基于maven-jar-plugin打造一款自动识别主类的maven打包插件
  • 单元测试的一般步骤
  • 20. git diff
  • 超简单的git学习教程
  • Spring Boot 项目中发布流式接口支持实时数据向客户端推送
  • 硬件电路(24)-NE555振荡电路
  • vue的基本结构
  • 用python脚本怎么实现:把一个文件夹里面.png文件没有固定名称,复制到另外一个文件夹按顺序命名?
  • 强制重装及验证onnxruntime-gpu是否正确工作