正常batch
micro_shape:每一个样本的shape
full_shape: 一个batch的shape
macro_shape:shape[:-len(micro_shape)]
例:文本的full_shape=(b, n, d), 其中b为batch_size,d为嵌入维度, 则macro_shape=(b,), micro_shape=(n, d)
例:图像的full_shape=(b, c, w, h), 其中b为batch_size,c为通道数,w和h分别为宽和高,则micro_shape=(c, w, h), macro_shape=(b, )
macro_shape在大部分情况下就是(b,)
Batch flattening
假设b个样本,micro_shape为
batch flattening 后 full_shape=
macro_shape被约定为(b, )
micro_shape被约定为,即少了一维,那一维被batch flatten占了。
在batch flattening的情况下,原来的shape为的张量,为了与full_shape的数据一起运算,需要被拓展为