Preface
Recently, while writing PyTorch code, I ran into this error: TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
After looking into it, I found the cause: I had passed a CUDA tensor to matplotlib, and matplotlib does not know how to handle CUDA tensors, hence the TypeError. Reference: "TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first;'"
Main text
1. The error
screen = em.get_processed_screen()
plt.figure()
plt.imshow(screen.squeeze(0).permute(1,2,0), interpolation='none')
plt.title('Processed screen example')
plt.show()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[38], line 4
1 screen = em.get_processed_screen()
3 plt.figure()
----> 4 plt.imshow(screen.squeeze(0).permute(1,2,0), interpolation='none')
5 plt.title('Processed screen example')
6 plt.show()
File ~/miniconda3/lib/python3.8/site-packages/matplotlib/pyplot.py:2695, in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)
2689 @_copy_docstring_and_deprecators(Axes.imshow)
2690 def imshow(
2691 X, cmap=None, norm=None, *, aspect=None, interpolation=None,
2692 alpha=None, vmin=None, vmax=None, origin=None, extent=None,
2693 interpolation_stage=None, filternorm=True, filterrad=4.0,
2694 resample=None, url=None, data=None, **kwargs):
-> 2695 __ret = gca().imshow(
2696 X, cmap=cmap, norm=norm, aspect=aspect,
2697 interpolation=interpolation, alpha=alpha, vmin=vmin,
2698 vmax=vmax, origin=origin, extent=extent,
2699 interpolation_stage=interpolation_stage,
2700 filternorm=filternorm, filterrad=filterrad, resample=resample,
2701 url=url, **({"data": data} if data is not None else {}),
2702 **kwargs)
2703 sci(__ret)
2704 return __ret
File ~/miniconda3/lib/python3.8/site-packages/matplotlib/__init__.py:1442, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
1439 @functools.wraps(func)
1440 def inner(ax, *args, data=None, **kwargs):
1441 if data is None:
-> 1442 return func(ax, *map(sanitize_sequence, args), **kwargs)
1444 bound = new_sig.bind(ax, *args, **kwargs)
1445 auto_label = (bound.arguments.get(label_namer)
1446 or bound.kwargs.get(label_namer))
File ~/miniconda3/lib/python3.8/site-packages/matplotlib/axes/_axes.py:5665, in Axes.imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)
5657 self.set_aspect(aspect)
5658 im = mimage.AxesImage(self, cmap=cmap, norm=norm,
5659 interpolation=interpolation, origin=origin,
5660 extent=extent, filternorm=filternorm,
5661 filterrad=filterrad, resample=resample,
5662 interpolation_stage=interpolation_stage,
5663 **kwargs)
-> 5665 im.set_data(X)
5666 im.set_alpha(alpha)
5667 if im.get_clip_path() is None:
5668 # image does not already have clipping set, clip to axes patch
File ~/miniconda3/lib/python3.8/site-packages/matplotlib/image.py:697, in _ImageBase.set_data(self, A)
695 if isinstance(A, PIL.Image.Image):
696 A = pil_to_array(A) # Needed e.g. to apply png palette.
--> 697 self._A = cbook.safe_masked_invalid(A, copy=True)
699 if (self._A.dtype != np.uint8 and
700 not np.can_cast(self._A.dtype, float, "same_kind")):
701 raise TypeError("Image data of dtype {} cannot be converted to "
702 "float".format(self._A.dtype))
File ~/miniconda3/lib/python3.8/site-packages/matplotlib/cbook/__init__.py:709, in safe_masked_invalid(x, copy)
708 def safe_masked_invalid(x, copy=False):
--> 709 x = np.array(x, subok=True, copy=copy)
710 if not x.dtype.isnative:
711 # If we have already made a copy, do the byteswap in place, else make a
712 # copy with the byte order swapped.
713 x = x.byteswap(inplace=copy).newbyteorder('N') # Swap to native order.
File ~/miniconda3/lib/python3.8/site-packages/torch/_tensor.py:970, in Tensor.__array__(self, dtype)
968 return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
969 if dtype is None:
--> 970 return self.numpy()
971 else:
972 return self.numpy().astype(dtype, copy=False)
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
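The last frames of the traceback show where the error is actually raised: matplotlib's cbook.safe_masked_invalid calls np.array(x), which invokes Tensor.__array__, which in turn calls Tensor.numpy(). A minimal sketch that reproduces the same call chain without matplotlib (the function name is hypothetical, and the CUDA branch only runs if a GPU is available):

```python
import numpy as np
import torch


def to_array_like_matplotlib(t: torch.Tensor) -> np.ndarray:
    # Roughly what matplotlib's safe_masked_invalid does: np.array() falls back
    # to Tensor.__array__, which calls Tensor.numpy() on the tensor.
    return np.array(t)


cpu_tensor = torch.arange(6.0).reshape(2, 3)
cpu_result = to_array_like_matplotlib(cpu_tensor)  # works: data is already in host memory

cuda_error = None
if torch.cuda.is_available():
    try:
        to_array_like_matplotlib(cpu_tensor.to("cuda"))
    except TypeError as err:
        # numpy() refuses tensors that live in GPU memory
        cuda_error = str(err)
```

On a machine with a GPU, cuda_error holds the same message as the traceback; on a CPU-only machine it stays None.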
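2. The cause
The article "TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first;'" explains it clearly: matplotlib does not know what a CUDA tensor is, yet our code passed a CUDA tensor straight into a matplotlib function. More precisely, as the last frames of the traceback show, matplotlib converts its input with np.array(), which calls Tensor.__array__ and then Tensor.numpy(), and numpy() only works on tensors in CPU (host) memory. The fix is to append .detach().cpu().numpy() to the CUDA tensor: .detach() removes it from the autograd graph (numpy() also refuses tensors that require gradients), .cpu() copies it to host memory, and .numpy() turns it into a NumPy array.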
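If you convert tensors for plotting in several places, the suffix can be wrapped in a small helper. tensor_to_numpy below is a hypothetical name, not a PyTorch API. Note that .cpu() returns the tensor unchanged if it is already on the CPU, so the same call works on both devices:

```python
import numpy as np
import torch


def tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: make any tensor safe to pass to NumPy/matplotlib."""
    # detach(): drop the autograd graph so numpy() does not raise on requires_grad tensors
    # cpu():    copy GPU data into host memory (no copy if already on the CPU)
    # numpy():  expose the CPU tensor's data as an ndarray
    return t.detach().cpu().numpy()


device = "cuda" if torch.cuda.is_available() else "cpu"
weights = torch.ones(2, 3, device=device, requires_grad=True)
arr = tensor_to_numpy(weights)
print(type(arr), arr.shape)  # <class 'numpy.ndarray'> (2, 3)
```

Calling weights.numpy() directly would fail here even on the CPU, because the tensor requires gradients; that is why .detach() comes first.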
3. The fix
Problem solved!
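import matplotlib.pyplot as plt

screen = em.get_processed_screen()
plt.figure()
# .detach().cpu().numpy() moves the image out of GPU memory into a NumPy array matplotlib can read
plt.imshow(screen.squeeze(0).permute(1,2,0).detach().cpu().numpy(), interpolation='none')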
plt.title('Processed screen example')
plt.show()
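To check the fix outside the original environment, here is a self-contained sketch. fake_processed_screen is a hypothetical stand-in for em.get_processed_screen(), assumed to return a float tensor of shape (1, C, H, W) with values in [0, 1]; the Agg backend and an in-memory PNG replace plt.show() so it also runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import torch


def fake_processed_screen(device: str) -> torch.Tensor:
    # Hypothetical stand-in for em.get_processed_screen(): one RGB image, shape (1, C, H, W)
    return torch.rand(1, 3, 40, 90, device=device)


device = "cuda" if torch.cuda.is_available() else "cpu"
screen = fake_processed_screen(device)

# (1, C, H, W) -> (C, H, W) -> (H, W, C), then copy to host memory as a NumPy array
image = screen.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

fig = plt.figure()
plt.imshow(image, interpolation='none')
plt.title('Processed screen example')
buf = io.BytesIO()
fig.savefig(buf, format="png")  # render to memory instead of opening a window
plt.close(fig)
```

The permute(1, 2, 0) step matters as well: imshow expects color images as (height, width, channels), while PyTorch stores them as (channels, height, width).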
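Summary
Data types matter too!
When using PyTorch, remember that PyTorch has its own data types: a tensor is not a NumPy array, and a tensor on the GPU must be copied back to the CPU before NumPy-based libraries such as matplotlib can use it.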
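References
"TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first;'"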
Q.E.D.