您的当前位置:首页正文

记录自己pytorch加载数据遇到的坑们

2024-11-23 来源:个人技术集锦

第0个坑:

自己的类别少一个,但是训练的前33个epoch没有出现问题,到33个之后就中断了。cuda device assert error之类的错误,这种问题看得多了,就知道肯定是类别和输出或者损失函数的weight数量不一致。但是为什么前33个epoch没有问题呢,是因为数据有随机crop,触发到这个错误的可能性很低。另外也体现了我的类别不均衡。

第一个坑:

我拿到手的代码,yuv数据,转成了float32,然后用这个样的数据去转成rgb做数据增强。但是转成float32数据的范围还是uint8的范围,这样转成的rgb数据就是不对的。

第二个坑:

后来我就将float32给去掉了,这样转rgb的时候是完全没有问题的。但是使用pytroch的transform的to tensor的时候。他是uint的类型,所以转成tensor的时候自动除了255,然后再减去均值,得到的数据分布范围就基本上都在-128~-108之间了。这样收敛也很慢。

第三个坑:

这个和第二个基本一样。gt_mask我给的255忽略值,也进行了同样的操作。结果得到了1。1是一个已定义的类别,这样忽略的地方都给我变成了1,肯定收敛不了。这个除1是知道的,但是还是遇到这个问题了。

if isinstance(img, torch.ByteTensor):
    return img.float().div(255)
else:
    return img

排查经验: 在送入网络之前进行可视化。往后排查,看看哪里有问题。最终通过是否收敛再排查一次。

显示全文