Preface
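Recently I've been working on an AE-RL project and somehow managed to write a whole pile of bugs. This is one of them, so I'm recording it here.
The cause: inside the transform pipeline, the data passed to Resize() was neither a PIL Image nor a Tensor.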
Main Text
1. Unexpected type <class 'numpy.ndarray'>
Out of nowhere, the following error appeared:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[8], line 50
47 print(f'Accuracy on test set: {100 * correct / total:.2f}%')
49 # Train the model
---> 50 train_model(10)
52 # Test the model
53 test_model()
Cell In[8], line 18, in train_model(num_epochs)
16 for epoch in range(num_epochs):
17 running_loss = 0.0
---> 18 for images, labels in train_loader:
19 images, labels = images.to(device), labels.to(device)
21 # Forward pass
File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
628 if self._sampler_iter is None:
629 # TODO(https://github.com/pytorch/pytorch/issues/76750)
630 self._reset() # type: ignore[call-arg]
--> 631 data = self._next_data()
632 self._num_yielded += 1
633 if self._dataset_kind == _DatasetKind.Iterable and \
634 self._IterableDataset_len_called is not None and \
635 self._num_yielded > self._IterableDataset_len_called:
File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
673 def _next_data(self):
674 index = self._next_index() # may raise StopIteration
--> 675 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
676 if self._pin_memory:
677 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File /lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
File /sft_train/data/dataset.py:19, in CustomImageDataset.__getitem__(self, idx)
17 label = self.labels[idx]
18 if self.transform:
---> 19 image = self.transform(image)
20 return image, label
File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:95, in Compose.__call__(self, img)
93 def __call__(self, img):
94 for t in self.transforms:
---> 95 img = t(img)
96 return img
File /lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:354, in Resize.forward(self, img)
346 def forward(self, img):
347 """
348 Args:
349 img (PIL Image or Tensor): Image to be scaled.
(...)
352 PIL Image or Tensor: Rescaled image.
353 """
--> 354 return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
File /lib/python3.12/site-packages/torchvision/transforms/functional.py:456, in resize(img, size, interpolation, max_size, antialias)
450 if max_size is not None and len(size) != 1:
451 raise ValueError(
452 "max_size should only be passed if size specifies the length of the smaller edge, "
453 "i.e. size should be an int or a sequence of length 1 in torchscript mode."
454 )
--> 456 _, image_height, image_width = get_dimensions(img)
457 if isinstance(size, int):
458 size = [size]
File /lib/python3.12/site-packages/torchvision/transforms/functional.py:80, in get_dimensions(img)
77 if isinstance(img, torch.Tensor):
78 return F_t.get_dimensions(img)
---> 80 return F_pil.get_dimensions(img)
File /lib/python3.12/site-packages/torchvision/transforms/_functional_pil.py:31, in get_dimensions(img)
29 width, height = img.size
30 return [channels, height, width]
---> 31 raise TypeError(f"Unexpected type {type(img)}")
TypeError: Unexpected type <class 'numpy.ndarray'>
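After looking into it: the Resize() transform expects its input image to be a PIL.Image or a torch.Tensor, but here it received a numpy.ndarray. Because of that, Resize() cannot read the image's dimensions: get_dimensions() falls through to the PIL branch and raises the TypeError.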
Then I checked the custom transform pipelines:
from torchvision import transforms

# Original (buggy) order: Resize() runs first and receives a numpy.ndarray
mnist_train_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_val_test_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
Sure enough, Resize() comes before ToTensor(), so it receives the raw numpy.ndarray. Moving ToTensor() in front of Resize() fixes it.
2. Solution
The fix is simply to swap ToTensor() and Resize():
from torchvision import transforms

# Fixed order: ToTensor() converts the numpy.ndarray to a Tensor before Resize()
mnist_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,))
])
mnist_val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,))
])
As shown in the screenshot, it works!
3. Side Note
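PyTorch, you lazybones: the TODO from 2022 (line 629 of dataloader.py in the traceback above) still hasn't been done.
It points to an open issue from 2022, #76750 [3].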
Summary
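From now on, in every transform pipeline, I'll try to put one of these two first:
- transforms.ToTensor()
- transforms.ToPILImage()
After all, the transforms that follow are usually torchvision's built-in ones, which expect a PIL Image or a Tensor rather than a numpy.ndarray.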
References
[1] ChatGPT
[2] Myself
[3] Bug in dataloader iterator found by mypy #76750
Q.E.D.