小编典典

将索引数组转换为 1-hot 编码的 numpy 数组

all

假设我有一个 1d numpy 数组

a = array([1,0,3])

我想将其编码为 2D one-hot 数组

b = array([[0,1,0,0], [1,0,0,0], [0,0,0,1]])

有没有快速的方法来做到这一点?比循环a设置 的元素更快b,也就是说。


阅读 61

收藏
2022-04-19

共1个答案

小编典典

您的数组a定义了输出数组中非零元素的列。您还需要定义行,然后使用精美的索引:

>>> a = np.array([1, 0, 3])
>>> b = np.zeros((a.size, a.max()+1))
>>> b[np.arange(a.size),a] = 1
>>> b
array([[ 0.,  1.,  0.,  0.],
       [ 1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.]])
2022-04-19