我们通常用模型包含的参数量和计算量来衡量一个模型的复杂度。参数量指的是模型学习参数数量,它决定了模型的大小,以及内存资源的占用,当然,在训练过程中,模型的实际内存使用量并不仅仅由参数量来决定
模型的计算量指的是浮点数运算数量,记作 FLOPs (Floating Point Operations),注意和 FLOPS (Floating Point Operations Per Second) 的区别。
1. Bert 参数量
先根据 transformer 库中 BertModel 的实现,了解下其主要包含的部分:
class BertModel(BertPreTrainedModel): def __init__(self, config, add_pooling_layer=True): super().__init__(config) # ... 省略 # 1. 词嵌入部分 self.embeddings = BertEmbeddings(config) # 2. 编码器部分 self.encoder = BertEncoder(config) # 3. 输出层部分 self.pooler = BertPooler(config) if add_pooling_layer else None # ... 省略
1.1 词嵌入部分
class BertEmbeddings(nn.Module): def __init__(self, config): super().__init__() # 1. 词编码 self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) # 2. 位置编码 self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # 3. 段编码 self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) # 3. 规范化层 self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # ... 省略
以 bert-base-chinese 预训练模型为例,此处需要的部分配置信息如下:
BertConfig { "architectures": [ "BertForMaskedLM" ], "hidden_size": 768, "max_position_embeddings": 512, "type_vocab_size": 2, "vocab_size": 21128 }
计算过程如下:
- 词编码:
vocab_size*hidden_size = 21128*768 = 16226304
- 位置编码:
max_position_embeddings*hidden_size = 512*768 = 393216
- 段编码:
type_vocab_size*hidden_size = 2*768 = 1536
- 规范化层:
hidden_size*2 = 768*2 = 1536
最终得到,词嵌入部分总参数量为:
16226304 + 393216 + 1536 + 1536 = 16622592
1.2 编码器部分
class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
以 bert-base-chinese 预训练模型为例,此处需要的部分配置信息如下:
BertConfig { "architectures": [ "BertForMaskedLM" ], "hidden_size": 768, "num_attention_heads": 12, "num_hidden_layers": 12, "intermediate_size": 3072 }
从代码可以看到编码器部分包含 12 个重复的层,我们只需要统计出单个层的参数量即可得到总参数量。BertLayer 的组成部分:
class BertLayer(nn.Module): def __init__(self, config): super().__init__() # 1. 注意力层 self.attention = BertAttention(config) # 2. 中间结果 self.intermediate = BertIntermediate(config) # 3. 输出层 self.output = BertOutput(config)
注意力层的参数量为:
- 注意力层
12
个注意力头的参数量:(768*(768/12) + 64)*3*12 = 1771776
- 注意力层的输出层包括一个线性层,参数量:
768*768+768 = 590592
,一个规范化层,参数量:768*2=1536
,共计:590592+1536 = 592128
- 注意力层总参数量为:
1771776+592128 = 2363904
中间结果层的参数量:768*3072+3072 = 2362368
输出层的参数量:3072*768+768 + 768*2 = 2361600
最终编码器层的总参数量:
(2363904 + 2362368 + 2361600) * 12 = 85054464
1.3 输出层部分
class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh()
参数量:768*768+768 = 590592
1.4 总参数量
16622592 + 85054464 + 590592 = 102267648 # 占用内存 102267648*4/1024/1024 = 390.12M
程序计算:
from transformers import BertModel model = BertModel.from_pretrained('bert-base-chinese') total = 0 for name, parameter in model.named_parameters(): total += parameter.numel() # 102267648 print(total)
2. 模型计算量
模型浮点计算量的一些单位表示如下:
- FLOPs 系列:
- MFLOPS(mega-FLOPs)等于一百万(10^6)次的浮点运算
- GFLOPS(giga-FLOPs)等于十亿(10^9)次的浮点运算
- TFLOPS(tera-FLOPs)等于一万亿(10^12)次的浮点运算
- PFLOPS(peta-FLOPs)等于一千万亿(10^15)次的浮点运算
- EFLOPS(exa-FLOPs)等于一百京(10^18)次的浮点运算
- ZFLOPS(zetta-FLOPs)等于十万京(10^21)次的浮点运算
- MACs 系列:
- MMACs(mega-MACs)等于一百万(10^6)次的乘和加浮点运算
- GMACs(giga-FLOPs)等于十亿(10^9)次的乘和加浮点运算
import torch import torch.nn as nn from transformers import BertModel from transformers import BertConfig import transformers # pip install fvcore # https://github.com/facebookresearch/fvcore # https://github.com/facebookresearch/fvcore/blob/main/tests/test_flop_count.py from fvcore.nn.flop_count import flop_count # pip install thop from thop import profile # pip install ptflops # https://github.com/sovrasov/flops-counter.pytorch from ptflops import get_model_complexity_info class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.linear1 = nn.Linear(1024, 1024, bias=False) self.linear2 = nn.Linear(1024, 2048, bias=False) def forward(self, inputs): inputs = self.linear1(inputs) inputs = nn.Tanh()(inputs) inputs = self.linear2(inputs) inputs = nn.Tanh()(inputs) inputs = nn.Sigmoid()(inputs) return inputs # 初始化模型 model = Net() # 1. thop def test01(): inputs = torch.randn(size=[1, 1024]) flops, params = profile(model=model, inputs=(inputs,)) print('Net 模型计算量:', flops, 'flop') # 2. ptflops def test02(): # 输出的是 MACs flops, params = get_model_complexity_info(model=model, input_res=(1, 1024), print_per_layer_stat=False, as_strings=False, flops_units='Flop') print('Net 模型计算量:', flops, 'flop') # 3. fvcore def test03(): inputs = torch.randn(size=[1, 1024]) flops, counter = flop_count(model=model, inputs=(inputs,)) # 输出单位是 GFLOPs print('Net 模型计算量:', flops, 'GFLOPs', flops['linear']*1e9, 'flop') if __name__ == '__main__': test01() print('-' * 30) test02() print('-' * 30) test03()
程序执行结果:
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>. Net 模型计算量: 3145728.0 flop ------------------------------ Net 模型计算量: 3145728.0 flop ------------------------------ Net 模型计算量: defaultdict(<class 'float'>, {'linear': 0.003145728}) GFLOPs 3145728.0 flop Unsupported operator aten::tanh encountered 2 time(s) Unsupported operator aten::sigmoid encountered 1 time(s)