小编典典

“Flatten”在 Keras 中的作用是什么?

all

我试图了解该Flatten功能在 Keras 中的作用。下面是我的代码,它是一个简单的两层网络。它接收形状为 (3, 2) 的二维数据,输出形状为
(1, 4) 的一维数据:

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

这打印出y具有形状 (1, 4)。但是,如果我删除这Flatten条线,那么它会打印出y形状为 (1, 3, 4)。

我不明白这一点。根据我对神经网络的理解,这个model.add(Dense(16, input_shape=(3, 2)))函数是创建一个隐藏的全连接层,有 16 个节点。这些节点中的每一个都连接到每个 3x2 输入元素。因此,第一层输出的 16
个节点已经是“平坦的”。因此,第一层的输出形状应该是 (1, 16)。然后,第二层将其作为输入,输出形状为 (1, 4) 的数据。

因此,如果第一层的输出已经是“平坦的”并且形状为 (1, 16),我为什么需要进一步将其展平?


阅读 71

收藏
2022-08-24

共1个答案

小编典典

如果你阅读 Keras 文档条目Dense,你会看到这个调用:

Dense(16, input_shape=(5,3))

将产生一个Dense具有 3 个输入和 16 个输出的网络,它们将独立应用于 5 个步骤中的每一个。因此,如果D(x)将 3 维向量转换为 16
维向量,您将从层得到的输出将是一系列向量:[D(x[0,:]), D(x[1,:]),..., D(x[4,:])]带有 shape (5, 16)。为了获得您指定的行为,您可以首先Flatten将输入输入到 15 维向量,然后应用Dense

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

编辑: 有些人难以理解 - 这里有一张解释图片:

在此处输入图像描述

2022-08-24