Preface

Recently, while working on my AE-RL project, I somehow produced a pile of bugs. This is one of them, so I'm writing it down.

It turned out to be caused by the data fed to resize() inside the transform not being a PIL Image or a Tensor.

Main Text

1. Unexpected type <class 'numpy.ndarray'>

Out of nowhere, the following error appeared:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 50
     47     print(f'Accuracy on test set: {100 * correct / total:.2f}%')
     49 # Train the model
---> 50 train_model(10)
     52 # Test the model
     53 test_model()

Cell In[8], line 18, in train_model(num_epochs)
     16 for epoch in range(num_epochs):
     17     running_loss = 0.0
---> 18     for images, labels in train_loader:
     19         images, labels = images.to(device), labels.to(device)
     21         # Forward pass

File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /sft_train/data/dataset.py:19, in CustomImageDataset.__getitem__(self, idx)
     17 label = self.labels[idx]
     18 if self.transform:
---> 19     image = self.transform(image)
     20 return image, label

File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:95, in Compose.__call__(self, img)
     93 def __call__(self, img):
     94     for t in self.transforms:
---> 95         img = t(img)
     96     return img

File /lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:354, in Resize.forward(self, img)
    346 def forward(self, img):
    347     """
    348     Args:
    349         img (PIL Image or Tensor): Image to be scaled.
   (...)
    352         PIL Image or Tensor: Rescaled image.
    353     """
--> 354     return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

File /lib/python3.12/site-packages/torchvision/transforms/functional.py:456, in resize(img, size, interpolation, max_size, antialias)
    450     if max_size is not None and len(size) != 1:
    451         raise ValueError(
    452             "max_size should only be passed if size specifies the length of the smaller edge, "
    453             "i.e. size should be an int or a sequence of length 1 in torchscript mode."
    454         )
--> 456 _, image_height, image_width = get_dimensions(img)
    457 if isinstance(size, int):
    458     size = [size]

File /lib/python3.12/site-packages/torchvision/transforms/functional.py:80, in get_dimensions(img)
     77 if isinstance(img, torch.Tensor):
     78     return F_t.get_dimensions(img)
---> 80 return F_pil.get_dimensions(img)

File /lib/python3.12/site-packages/torchvision/transforms/_functional_pil.py:31, in get_dimensions(img)
     29     width, height = img.size
     30     return [channels, height, width]
---> 31 raise TypeError(f"Unexpected type {type(img)}")

TypeError: Unexpected type <class 'numpy.ndarray'>

A quick search shows that the Resize() transform expects its input image to be a PIL.Image or a torch.Tensor, but here a numpy.ndarray is being passed in (that is what the custom CustomImageDataset hands to self.transform). As a result, get_dimensions() inside Resize() cannot read the image's dimensions and raises the TypeError.
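
A minimal reproduction of the error (the 28x28 uint8 array below is only an assumption about what the custom Dataset returns; any numpy.ndarray triggers the same failure):

import numpy as np
from torchvision import transforms

img = np.zeros((28, 28), dtype=np.uint8)   # numpy.ndarray, neither PIL.Image nor torch.Tensor

try:
    transforms.Resize((28, 28))(img)       # get_dimensions() cannot handle an ndarray
except TypeError as e:
    print(e)                               # Unexpected type <class 'numpy.ndarray'>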

So I checked the custom transforms:

mnist_train_transform = transforms.Compose([
    transforms.Resize((28, 28)),  
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))  
])

mnist_val_test_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

Sure enough: Resize() sits in front of ToTensor(), so the raw numpy.ndarray reaches Resize() before it has been converted. Moving ToTensor() in front of Resize() solves it, because ToTensor() accepts a numpy.ndarray and turns it into a Tensor that Resize() can handle.

2. Solution

The fix is simply to swap the positions of ToTensor() and Resize():

mnist_train_transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Resize((28, 28)),  
    transforms.Normalize((0.5,), (0.5,))  
])

mnist_val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,))
])
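
A quick sanity check of the fixed pipeline (assuming the Dataset yields a single-channel 28x28 uint8 array; the random data here is just a stand-in):

import numpy as np

img = np.random.randint(0, 256, (28, 28), dtype=np.uint8)   # stand-in for the Dataset's ndarray
out = mnist_train_transform(img)
print(out.shape, out.dtype)   # torch.Size([1, 28, 28]) torch.float32

Note that ToTensor() also rescales uint8 values to [0, 1] and moves the channel dimension to the front, so the Normalize((0.5,), (0.5,)) that follows still maps the data to roughly [-1, 1].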

As shown in the screenshot, it works!
1728830401764.jpg

3. Aside

PyTorch is really dragging its feet: that TODO from 2022 still hasn't been finished.

1728830032891.jpg

The open issue from 2022:
image.png

Summary

Next time, in any transform pipeline, try to put one of these two conversions at the front:

  1. transforms.ToTensor()
  2. transforms.ToPILImage()

After all, the transforms that come after them are most likely the ones shipped with PyTorch/torchvision, which expect a Tensor or a PIL Image; see the sketch below.
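
Here is a sketch of that rule of thumb. The 28x28 size and the (0.5,), (0.5,) normalization simply mirror the values used above; pick the conversion that matches what the later transforms accept.

from torchvision import transforms

# Tensor route: convert to a Tensor first, then use Tensor-capable transforms.
tensor_first = transforms.Compose([
    transforms.ToTensor(),               # PIL.Image / numpy.ndarray -> float Tensor in [0, 1]
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,)),
])

# PIL route: convert to a PIL Image first if the later transforms work on PIL Images.
pil_first = transforms.Compose([
    transforms.ToPILImage(),             # numpy.ndarray / Tensor -> PIL.Image
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])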

References

[1] ChatGPT
[2] Myself
[3] pytorch/pytorch issue #76750: Bug in dataloader iterator found by mypy

Q.E.D.

