ImageFolder---合并dataset

ImageFolder 用于读取文件夹内的图片与类别,生成Map-style datasets形式的数据集,以便DataLoader迭代。ImageFolder使用格式:

root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png

但是由于各种原因,如原生数据存储格式,某些操作系统(如centos)中单个目录中最大文件数量存在限制(hdf5 yyds)等等,使得数据的存储是分开的,如:

sub_root1/dog/xxx.png
sub_root1/dog/xxy.png
sub_root1/dog/[..
### pseudo-labeling 策略概述 pseudo-labeling 是一种用于半监督学习的技术,在这种技术下,模型不仅依赖于少量标注的数据集进行训练,还能够利用大量未标记的数据来提高泛化能力。该方法的核心在于使用当前模型预测未标记样本的标签,将这些高置信度的预测作为伪标签加入到训练集中[^3]。 #### 原理 pseudo-labeling 的基本原理是在给定的小规模有标签数据上预训练一个初始分类器之后,用这个初步训练好的分类器去推断大规模无标签数据上的类别分布。对于那些被分类器高度确信属于某一类别的实例,则赋予相应的伪标签将其视为真实的标签参与到后续迭代中的再训练过程中。这种方法可以有效缓解因缺乏足够的带标签样例而导致过拟合的风险,同时也增强了模型对未知模式的理解力和适应性[^1]。 #### 应用场景 pseudo-labeling 广泛应用于各种计算机视觉任务中,比如图像识别、目标检测以及分割等领域。此外,在自然语言处理方面也有着广泛的应用案例,例如文本分类、情感分析等。特别是在医疗影像诊断这样的领域里,获取高质量的手动标注成本极高,而采用此策略可以在不增加太多额外开销的前提下显著提升系统的准确性与可靠性[^2]。 #### 实现方式 以下是基于 PyTorch 框架的一个简单实现例子: ```python import torch from torchvision import datasets, transforms from torch.utils.data.sampler import SubsetRandomSampler import numpy as np def train_pseudo_label(model, labeled_loader, unlabeled_loader, optimizer, criterion, device='cuda'): model.train() for (l_data, l_targets), u_data in zip(labeled_loader, unlabeled_loader): # 将已知标签的数据送入GPU/CPU设备 l_data, l_targets = l_data.to(device), l_targets.to(device) # 对未标记的数据也做同样的操作 u_data = u_data.to(device) # 获取未标记数据经过网络后的输出 with torch.no_grad(): outputs_u = model(u_data).detach() # 计算软标签(即概率) soft_labels = torch.softmax(outputs_u / T, dim=1) # 温度系数T控制锐利程度 # 只保留最有可能的真实标签及其对应的索引位置 max_probs, targets_u = torch.max(soft_labels, dim=-1) mask = max_probs.ge(threshold).float() # 设置阈值过滤低质量伪标签 # 合并两部分数据一起更新权重 combined_inputs = torch.cat([l_data, u_data], 0) combined_targets = torch.cat([l_targets, targets_u * mask.long()], 0) logits = model(combined_inputs) loss = criterion(logits, combined_targets) optimizer.zero_grad() loss.backward() optimizer.step() if __name__ == '__main__': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) dataset_labeled = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) dataset_unlabeled = datasets.ImageFolder('path_to_your_unlabelled_dataset', transform=transform_train) num_train = len(dataset_labeled) indices = list(range(num_train)) split = int(np.floor(valid_size * num_train)) np.random.shuffle(indices) sampler_labeled = SubsetRandomSampler(indices[:split]) sampler_unlabeled = SubsetRandomSampler(indices[split:num_train]) loader_labeled = torch.utils.data.DataLoader( dataset_labeled, batch_size=batch_size, sampler=sampler_labeled, pin_memory=True, ) loader_unlabeled = torch.utils.data.DataLoader( dataset_unlabeled, batch_size=batch_size*mu, sampler=sampler_unlabeled, pin_memory=True, ) net = YourModel().to(device) opt = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) ce_loss_fn = nn.CrossEntropyLoss(reduction="none") for epoch in range(start_epoch, epochs): train_pseudo_label(net, loader_labeled, loader_unlabeled, opt, ce_loss_fn) ``` 在这个代码片段中,`train_pseudo_label()` 函数展示了如何在一个典型的训练循环内应用 pseudo-labeling 技术。这里的关键点包括但不限于:为未标记数据分配伪标签;设置温度参数 `T` 来调整 softmax 输出的概率分布;定义最小置信度 `threshold` 进行筛选以排除不确定性强的例子;最后则是组合来自两个不同来源的数据来进行联合优化[^4]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值