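Preface
I've recently been working on an AE-RL project and somehow managed to write a pile of bugs. This is one of them, so I'm writing it down.
It happens when Resize() runs inside a transform pipeline and its input is neither a PIL Image nor a Tensor.
Main text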
1. Unexpected type <class 'numpy.ndarray'>
Out of nowhere, the following error appeared:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[8], line 50
     47     print(f'Accuracy on test set: {100 * correct / total:.2f}%')
     49 # train the model
---> 50 train_model(10)
     52 # test the model
     53 test_model()

Cell In[8], line 18, in train_model(num_epochs)
     16 for epoch in range(num_epochs):
     17     running_loss = 0.0
---> 18     for images, labels in train_loader:
     19         images, labels = images.to(device), labels.to(device)
     21         # forward pass

File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File /lib/python3.12/site-packages/torch/utils/data/dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File /lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File /sft_train/data/dataset.py:19, in CustomImageDataset.__getitem__(self, idx)
     17 label = self.labels[idx]
     18 if self.transform:
---> 19     image = self.transform(image)
     20 return image, label

File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:95, in Compose.__call__(self, img)
     93 def __call__(self, img):
     94     for t in self.transforms:
---> 95         img = t(img)
     96     return img

File /lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /lib/python3.12/site-packages/torchvision/transforms/transforms.py:354, in Resize.forward(self, img)
    346 def forward(self, img):
    347     """
    348     Args:
    349         img (PIL Image or Tensor): Image to be scaled.
   (...)
    352         PIL Image or Tensor: Rescaled image.
    353     """
--> 354     return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)

File /lib/python3.12/site-packages/torchvision/transforms/functional.py:456, in resize(img, size, interpolation, max_size, antialias)
    450     if max_size is not None and len(size) != 1:
    451         raise ValueError(
    452             "max_size should only be passed if size specifies the length of the smaller edge, "
    453             "i.e. size should be an int or a sequence of length 1 in torchscript mode."
    454         )
--> 456 _, image_height, image_width = get_dimensions(img)
    457 if isinstance(size, int):
    458     size = [size]

File /lib/python3.12/site-packages/torchvision/transforms/functional.py:80, in get_dimensions(img)
     77 if isinstance(img, torch.Tensor):
     78     return F_t.get_dimensions(img)
---> 80 return F_pil.get_dimensions(img)

File /lib/python3.12/site-packages/torchvision/transforms/_functional_pil.py:31, in get_dimensions(img)
     29     width, height = img.size
     30     return [channels, height, width]
---> 31 raise TypeError(f"Unexpected type {type(img)}")

TypeError: Unexpected type <class 'numpy.ndarray'>
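After some digging, the cause turned out to be that the Resize() transform expects its input image to be a PIL.Image or a torch.Tensor, but here it was handed a numpy.ndarray. As a result, Resize() cannot read the image's dimensions.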
I then checked the transforms I had defined:
from torchvision import transforms

# Problematic order: Resize() runs first and receives the raw numpy.ndarray.
mnist_train_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_val_test_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
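Sure enough, Resize() comes before ToTensor(). ToTensor() accepts a numpy.ndarray and returns a torch.Tensor, which Resize() can handle, so putting ToTensor() ahead of Resize() is enough.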
2. Solution
The fix is to swap the positions of ToTensor() and Resize():
from torchvision import transforms

# Fixed order: ToTensor() converts the ndarray first, then Resize() works on a Tensor.
mnist_train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,))
])

mnist_val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((28, 28)),
    transforms.Normalize((0.5,), (0.5,))
])
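3. A digression
PyTorch is being a bit lazy here: the TODO from 2022 that shows up in the traceback (dataloader.py, issue #76750) still hasn't been taken care of.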
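Summary
From now on, in every transform pipeline, try to put one of
- transforms.ToTensor()
- transforms.ToPILImage()
at the front, since the transforms that follow are most likely the ones that ship with PyTorch (torchvision), and those expect a PIL Image or a Tensor.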
References
[1] ChatGPT
[2] Myself
[3] Bug in dataloader iterator found by mypy #76750
Q.E.D.