Preface
I ran into this error out of nowhere, and it was puzzling at first.
It turned out that one of the labels was greater than or equal to n_classes.
Main Text
1. Assertion t < n_classes failed
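I got the following output. The first UserWarning about padding='same' is unrelated; the actual failures are the assertions raised by the NLL loss CUDA kernel: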
/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1031.)
return F.conv2d(input, weight, bias, self.stride,
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [9,0,0] Assertion `t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [10,0,0] Assertion `t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [22,0,0] Assertion `t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [24,0,0] Assertion `t < n_classes` failed.
../aten/src/ATen/native/cuda/Loss.cu:250: nll_loss_forward_reduce_cuda_kernel_2d: block: [0,0,0], thread: [25,0,0] Assertion `t < n_classes` failed.
Training: 0%| | 0/79 [00:00<?, ?it/s]
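The kernel named in the log, nll_loss_forward_reduce_cuda_kernel_2d, belongs to the NLL loss, which nn.CrossEntropyLoss also goes through for class-index targets. It asserts that every target t is smaller than n_classes. Because CUDA kernel assertions are reported asynchronously, it can be easier to check the labels on the CPU before training. The sketch below assumes the labels are available as a NumPy array; the helper name find_out_of_range_targets is hypothetical, and ignore_index=-100 mirrors PyTorch's default for these losses.

```python
import numpy as np

def find_out_of_range_targets(targets, n_classes, ignore_index=-100):
    """Return (positions, values) of targets outside [0, n_classes).

    Entries equal to ignore_index are skipped, mirroring how
    nll_loss / cross_entropy treat that value.
    """
    t = np.asarray(targets).astype(np.int64)  # widen to a signed type so negatives survive
    bad = ((t < 0) | (t >= n_classes)) & (t != ignore_index)
    positions = np.flatnonzero(bad)  # indices of the offending entries
    return positions, t[positions]

# Hypothetical batch of 10-class labels with one corrupted entry
labels = np.array([3, 0, 7, 255, 1], dtype=np.uint8)
pos, vals = find_out_of_range_targets(labels, n_classes=10)
print(pos, vals)
```

Here the stray 255 at position 3 is reported, which matches the kind of label that triggers the assertion above.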
After a careful check, I found a 255 mixed into the labels, which was strange:
Number of distinct classes: 10
Class labels: [ 0 1 2 3 4 5 6 7 8 255]
Count per class: [500 500 500 500 500 500 500 500 500 500]
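Statistics like the ones above can be produced with np.unique(..., return_counts=True). The author's actual checking script is not shown, so this is only a sketch: label_summary is a hypothetical name, and it assumes the labels are a NumPy array. Printing the dtype as well is an addition that helps spot the cause described in the next section.

```python
import numpy as np

def label_summary(labels):
    """Print the distinct labels, their counts and dtype; return (values, counts)."""
    values, counts = np.unique(np.asarray(labels), return_counts=True)
    print("Number of distinct classes:", len(values))
    print("Class labels:", values)
    print("Count per class:", counts)
    print("dtype:", values.dtype)  # uint8 here hints that "- 1" can wrap around
    return values, counts

# Hypothetical data with the same symptom: labels 0..8 plus a stray 255, 500 each
y = np.repeat(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 255], dtype=np.uint8), 500)
label_summary(y)
```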
2. Solution
It turned out the earlier code looked like this:
# Adjust labels to be zero-based
y_train = y_train - 1 # Now labels are from 0 to 9
y_test = y_test - 1
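The labels were evidently already 0-based (0 to 9), so this "- 1" was wrong to begin with. Labels 1 to 9 became 0 to 8, and label 0 became -1. The labels are most likely stored as an unsigned 8-bit integer type (uint8), which cannot represent -1, so the value wrapped around to 255. This is an unsigned underflow rather than an overflow of the int range. After commenting out these lines, the labels went back to: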
Number of distinct classes: 10
Class labels: [0 1 2 3 4 5 6 7 8 9]
Count per class: [500 500 500 500 500 500 500 500 500 500]
Success, done! The problem is gone!
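The wrap-around is easy to reproduce with NumPy. This sketch assumes the labels are a uint8 array, which is consistent with the 255 seen above but not confirmed by the original code. The helper to_zero_based is hypothetical and uses a simple heuristic: cast to a signed type first, and shift only when the smallest label is 1.

```python
import numpy as np

# Hypothetical labels that are ALREADY zero-based, stored as uint8
y = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.uint8)

# Unconditional "- 1": the result stays uint8, so 0 - 1 wraps around to 255
shifted = y - 1
print(shifted)  # 0 became 255; 1..9 became 0..8

def to_zero_based(labels):
    """Cast to int64 and shift only when labels actually start at 1 (heuristic)."""
    labels = np.asarray(labels).astype(np.int64)  # signed type: no wrap-around
    if labels.min() == 1:
        labels = labels - 1
    return labels

print(to_zero_based(y))      # already 0-based: left unchanged
print(to_zero_based(y + 1))  # 1-based input: shifted down to 0..9
```

Casting before any arithmetic is the key design choice: even if the shift were wrong, a signed type would produce a visible -1 instead of a silently valid-looking 255.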
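Summary
Whenever you start working with a new dataset, check carefully what the labels actually are, both their values and their dtype; otherwise you run into surprises like this one.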
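One way to avoid hand-written offsets such as "- 1" altogether is to derive the label mapping from the data itself. This is a sketch, not the author's code: encode_labels is a hypothetical helper built on np.unique(..., return_inverse=True), it assumes 1-D NumPy label arrays, and it returns int64, the usual dtype for class-index targets in PyTorch (torch.long).

```python
import numpy as np

def encode_labels(y_train, y_test):
    """Map arbitrary label values to contiguous indices 0..K-1.

    The mapping is learned from y_train and applied to y_test;
    a test label unseen in training raises a ValueError.
    """
    classes, y_train_idx = np.unique(np.asarray(y_train), return_inverse=True)
    y_test = np.asarray(y_test)
    y_test_idx = np.searchsorted(classes, y_test)  # candidate position in the sorted classes
    # searchsorted only gives an insertion point; check it is valid and the value matches
    if not (y_test_idx < len(classes)).all():
        raise ValueError("y_test contains labels not present in y_train")
    if not np.array_equal(classes[y_test_idx], y_test):
        raise ValueError("y_test contains labels not present in y_train")
    return classes, y_train_idx.astype(np.int64), y_test_idx.astype(np.int64)

# Hypothetical 1-based labels: they become 0-based without any manual offset
classes, y_tr, y_te = encode_labels(np.array([1, 2, 3, 3], dtype=np.uint8),
                                    np.array([3, 1], dtype=np.uint8))
print(classes, y_tr, y_te)
```

With this approach n_classes is simply len(classes), so every encoded target is guaranteed to satisfy t < n_classes.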
Q.E.D.