正常batch

micro_shape:每一个样本的shape

full_shape: 一个batch的shape

macro_shapeshape[:-len(micro_shape)]

例:文本的full_shape=(b, n, d), 其中b为batch_size,d为嵌入维度, 则macro_shape=(b,), micro_shape=(n, d)

例:图像的full_shape=(b, c, w, h), 其中b为batch_size,c为通道数,wh分别为宽和高,则micro_shape=(c, w, h), macro_shape=(b, )

macro_shape在大部分情况下就是(b,)

Batch flattening

假设b个样本,micro_shape

batch flattening 后 full_shape=

macro_shape被约定为(b, )

micro_shape被约定为,即少了一维,那一维被batch flatten占了。

在batch flattening的情况下,原来的shape为的张量,为了与full_shape的数据一起运算,需要被拓展为