Preface

Recently, while writing PyTorch code, I ran into this error: "TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first." Looking it up, I found the cause. I had passed a CUDA tensor to matplotlib, and matplotlib has no idea what a CUDA tensor is, so the call ended in a TypeError (reference: [1]).

Main Text

1. The Error

screen = em.get_processed_screen()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1,2,0), interpolation='none')
plt.title('Processed screen example')
plt.show()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[38], line 4
      1 screen = em.get_processed_screen()
      3 plt.figure()
----> 4 plt.imshow(screen.squeeze(0).permute(1,2,0), interpolation='none')
      5 plt.title('Processed screen example')
      6 plt.show()

File ~/miniconda3/lib/python3.8/site-packages/matplotlib/pyplot.py:2695, in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)
   2689 @_copy_docstring_and_deprecators(Axes.imshow)
   2690 def imshow(
   2691         X, cmap=None, norm=None, *, aspect=None, interpolation=None,
   2692         alpha=None, vmin=None, vmax=None, origin=None, extent=None,
   2693         interpolation_stage=None, filternorm=True, filterrad=4.0,
   2694         resample=None, url=None, data=None, **kwargs):
-> 2695     __ret = gca().imshow(
   2696         X, cmap=cmap, norm=norm, aspect=aspect,
   2697         interpolation=interpolation, alpha=alpha, vmin=vmin,
   2698         vmax=vmax, origin=origin, extent=extent,
   2699         interpolation_stage=interpolation_stage,
   2700         filternorm=filternorm, filterrad=filterrad, resample=resample,
   2701         url=url, **({"data": data} if data is not None else {}),
   2702         **kwargs)
   2703     sci(__ret)
   2704     return __ret

File ~/miniconda3/lib/python3.8/site-packages/matplotlib/__init__.py:1442, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1439 @functools.wraps(func)
   1440 def inner(ax, *args, data=None, **kwargs):
   1441     if data is None:
-> 1442         return func(ax, *map(sanitize_sequence, args), **kwargs)
   1444     bound = new_sig.bind(ax, *args, **kwargs)
   1445     auto_label = (bound.arguments.get(label_namer)
   1446                   or bound.kwargs.get(label_namer))

File ~/miniconda3/lib/python3.8/site-packages/matplotlib/axes/_axes.py:5665, in Axes.imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
   5657 self.set_aspect(aspect)
   5658 im = mimage.AxesImage(self, cmap=cmap, norm=norm,
   5659                       interpolation=interpolation, origin=origin,
   5660                       extent=extent, filternorm=filternorm,
   5661                       filterrad=filterrad, resample=resample,
   5662                       interpolation_stage=interpolation_stage,
   5663                       **kwargs)
-> 5665 im.set_data(X)
   5666 im.set_alpha(alpha)
   5667 if im.get_clip_path() is None:
   5668     # image does not already have clipping set, clip to axes patch

File ~/miniconda3/lib/python3.8/site-packages/matplotlib/image.py:697, in _ImageBase.set_data(self, A)
    695 if isinstance(A, PIL.Image.Image):
    696     A = pil_to_array(A)  # Needed e.g. to apply png palette.
--> 697 self._A = cbook.safe_masked_invalid(A, copy=True)
    699 if (self._A.dtype != np.uint8 and
    700         not np.can_cast(self._A.dtype, float, "same_kind")):
    701     raise TypeError("Image data of dtype {} cannot be converted to "
    702                     "float".format(self._A.dtype))

File ~/miniconda3/lib/python3.8/site-packages/matplotlib/cbook/__init__.py:709, in safe_masked_invalid(x, copy)
    708 def safe_masked_invalid(x, copy=False):
--> 709     x = np.array(x, subok=True, copy=copy)
    710     if not x.dtype.isnative:
    711         # If we have already made a copy, do the byteswap in place, else make a
    712         # copy with the byte order swapped.
    713         x = x.byteswap(inplace=copy).newbyteorder('N')  # Swap to native order.

File ~/miniconda3/lib/python3.8/site-packages/torch/_tensor.py:970, in Tensor.__array__(self, dtype)
    968     return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
    969 if dtype is None:
--> 970     return self.numpy()
    971 else:
    972     return self.numpy().astype(dtype, copy=False)

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

[Screenshot: Screen Shot 2023-08-15 at 13.11.09]
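To see the mechanism without the author's em object or a matplotlib figure, the last frames of the traceback can be replayed directly. This is a minimal sketch assuming only torch and numpy are installed. The random tensor is a hypothetical stand-in for the processed screen. On a CPU tensor the conversion succeeds. The CUDA branch only runs if a GPU is available, and it should then raise the same TypeError as above.

```python
import numpy as np
import torch

# Replays the last frames of the traceback without matplotlib:
# safe_masked_invalid() calls np.array(x, subok=True, copy=True),
# which calls Tensor.__array__(), which calls Tensor.numpy().
cpu_screen = torch.rand(3, 4, 4)                   # hypothetical stand-in image tensor (C, H, W)
arr = np.array(cpu_screen, subok=True, copy=True)  # works: the tensor is in host memory
print(arr.shape)  # (3, 4, 4)

if torch.cuda.is_available():
    gpu_screen = cpu_screen.to("cuda")
    try:
        np.array(gpu_screen, subok=True, copy=True)  # same call matplotlib makes
    except TypeError as err:
        print("Same error as in the traceback:", err)
```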

2. The Cause

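The reference article [1] explains it clearly: matplotlib does not know what a CUDA tensor is, yet our code passed a CUDA tensor as an argument to a matplotlib function. The last frames of the traceback show exactly where it breaks. matplotlib's cbook.safe_masked_invalid() calls np.array(x, subok=True, copy=copy), which invokes PyTorch's Tensor.__array__(), which in turn calls self.numpy(). Tensor.numpy() only works on tensors in CPU (host) memory, so a tensor on cuda:0 raises the TypeError. All we need to do is append .detach().cpu().numpy() to the CUDA tensor: .detach() takes it out of the autograd graph (numpy() also refuses tensors that require gradients), .cpu() copies it to host memory, and .numpy() turns it into a NumPy array that matplotlib can handle.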
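Each of the three calls does one job. Here is a small sketch that wraps the chain in a helper. The name to_numpy is hypothetical and not from the original code. It also shows why .detach() is needed: calling .numpy() on a tensor that requires gradients raises a RuntimeError, even on the CPU.

```python
import numpy as np
import torch


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: convert a tensor on any device to a NumPy array."""
    # detach(): drop the autograd graph; numpy() refuses tensors with requires_grad=True
    # cpu():    copy from GPU memory to host memory (returns the tensor itself if already on CPU)
    # numpy():  expose the CPU tensor's memory as an ndarray (no extra copy)
    return t.detach().cpu().numpy()


x = torch.ones(2, 3, requires_grad=True)
try:
    (x * 2).numpy()  # fails: the result is still part of the autograd graph
except RuntimeError as err:
    print("numpy() on a grad-tracking tensor fails:", err)

arr = to_numpy(x * 2)  # works regardless of device or requires_grad
print(type(arr), arr.shape, arr.dtype)
```

Because .cpu() returns the tensor itself when it is already in CPU memory, the same helper works unchanged whether the model runs on the GPU or the CPU.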

3. The Fix

Problem solved! Just append .detach().cpu().numpy() to the tensor before passing it to plt.imshow:

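import matplotlib.pyplot as plt

# em is an object defined earlier in the author's code (not shown);
# get_processed_screen() returns a (1, C, H, W) image tensor on cuda:0
screen = em.get_processed_screen()

plt.figure()
# (1, C, H, W) -> (H, W, C), then detach from autograd, copy to host memory, convert to NumPy
plt.imshow(screen.squeeze(0).permute(1,2,0).detach().cpu().numpy(), interpolation='none')
plt.title('Processed screen example')
plt.show()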

[Screenshot: Screen Shot 2023-08-15 at 13.11.46]
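To avoid repeating the conversion chain at every call site, the plotting code can be wrapped in a small device-agnostic function. This is a minimal sketch with several assumptions. show_screen is a hypothetical name. The random (1, 3, 40, 90) tensor stands in for em.get_processed_screen(). The Agg backend line is there only so the sketch runs without a display. In a notebook, drop it and call plt.show() after show_screen(screen).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere; drop this in a notebook
import matplotlib.pyplot as plt
import torch


def show_screen(screen: torch.Tensor, title: str = "Processed screen example"):
    """Hypothetical helper: plot a (1, C, H, W) image tensor that may live on any device."""
    img = screen.squeeze(0).permute(1, 2, 0)  # (1, C, H, W) -> (H, W, C), the layout imshow expects
    img = img.detach().cpu().numpy()          # leave autograd, copy to host memory, convert to NumPy
    fig, ax = plt.subplots()
    im = ax.imshow(img, interpolation="none")
    ax.set_title(title)
    return fig, im


device = "cuda" if torch.cuda.is_available() else "cpu"
screen = torch.rand(1, 3, 40, 90, device=device)  # stand-in for em.get_processed_screen()
fig, im = show_screen(screen)
```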

Summary

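Data types really do matter~!

When using PyTorch, remember that PyTorch has its own data types! A torch.Tensor, especially one on the GPU, is not a NumPy array, so convert it with .detach().cpu().numpy() before handing it to NumPy-based libraries such as matplotlib.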

References

[1] TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first;'
