在MXNet中,NDArray
是一个类,也是存储和变换数据的主要工具。如果你之前用过NumPy,你会发现NDArray
和NumPy的多维数组非常类似。然而,NDArray
提供GPU计算和自动求梯度等更多功能,这些使NDArray
更加适合深度学习。类似于TensorFlow的tensor与pytorch中的variable,学习NumPy操作方式,实现GPU计算,由于NumPy不支持GPU。
有时候我们需要将NDArray
和NumPy的多维数组相互转换来实现目标功能。在MXNet中可以通过array
函数和asnumpy
函数令数据在NDArray
和NumPy格式之间相互变换。下面将NDArray
实例变换成NumPy实例。
In [1]:
import mxnet as mx
In [2]:
x_nd = mx.nd.arange(12).reshape([-1,3])
x_nd
Out[2]:
[[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]]
<NDArray 4x3 @cpu(0)>
In [3]:
x_np = x_nd.asnumpy()
x_np
Out[3]:
array([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]], dtype=float32)
再将NumPy实例变换成NDArray
实例。
In [4]:
x_nd1 = mx.nd.array(x_np)
x_nd1
Out[4]:
[[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]]
<NDArray 4x3 @cpu(0)>