Bert 模型复杂度

我们通常用模型包含的参数量计算量来衡量一个模型的复杂度。参数量指的是模型学习参数数量,它决定了模型的大小,以及内存资源的占用,当然,在训练过程中,模型的实际内存使用量并不仅仅由参数量来决定

模型的计算量指的是浮点数运算数量,记作 FLOPs (Floating Point Operations),注意和 FLOPS (Floating Point Operations Per Second) 的区别。

1. Bert 参数量

先根据 transformer 库中 BertModel 的实现,了解下其主要包含的部分:

class BertModel(BertPreTrainedModel):

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        # ... 省略
        # 1. 词嵌入部分
        self.embeddings = BertEmbeddings(config)
        # 2. 编码器部分
        self.encoder = BertEncoder(config)
        # 3. 输出层部分
        self.pooler = BertPooler(config) if add_pooling_layer else None
        # ... 省略

1.1 词嵌入部分

class BertEmbeddings(nn.Module):

    def __init__(self, config):
        super().__init__()
        # 1. 词编码
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        # 2. 位置编码
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # 3. 段编码
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # 3. 规范化层
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # ... 省略

以 bert-base-chinese 预训练模型为例,此处需要的部分配置信息如下:

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "hidden_size": 768,
  "max_position_embeddings": 512,
  "type_vocab_size": 2,
  "vocab_size": 21128
}

计算过程如下:

  • 词编码: vocab_size*hidden_size = 21128*768 = 16226304
  • 位置编码: max_position_embeddings*hidden_size = 512*768 = 393216
  • 段编码: type_vocab_size*hidden_size = 2*768 = 1536
  • 规范化层:hidden_size*2 = 768*2 = 1536

最终得到,词嵌入部分总参数量为:

16226304 + 393216 + 1536 + 1536 = 16622592

1.2 编码器部分

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

以 bert-base-chinese 预训练模型为例,此处需要的部分配置信息如下:

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "hidden_size": 768,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "intermediate_size": 3072
}

从代码可以看到编码器部分包含 12 个重复的层,我们只需要统计出单个层的参数量即可得到总参数量。BertLayer 的组成部分:

class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 1. 注意力层
        self.attention = BertAttention(config)
        # 2. 中间结果
        self.intermediate = BertIntermediate(config)
        # 3. 输出层
        self.output = BertOutput(config)

注意力层的参数量为:

  • 注意力层 12 个注意力头的参数量:(768*(768/12) + 64)*3*12 = 1771776
  • 注意力层的输出层包括一个线性层,参数量:768*768+768 = 590592,一个规范化层,参数量:768*2=1536,共计:590592+1536 = 592128
  • 注意力层总参数量为:1771776+592128 = 2363904

中间结果层的参数量:768*3072+3072 = 2362368

输出层的参数量:3072*768+768 + 768*2 = 2361600

最终编码器层的总参数量:

(2363904 + 2362368 + 2361600) * 12 = 85054464

1.3 输出层部分

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

参数量:768*768+768 = 590592

1.4 总参数量

16622592 + 85054464 + 590592 = 102267648
# 占用内存
102267648*4/1024/1024 = 390.12M

程序计算:

from transformers import BertModel


model = BertModel.from_pretrained('bert-base-chinese')

total = 0
for name, parameter in model.named_parameters():
    total += parameter.numel()

# 102267648
print(total)

2. 模型计算量

模型浮点计算量的一些单位表示如下:

  1. FLOPs 系列:
    • MFLOPS(mega-FLOPs)等于一百万(10^6)次的浮点运算
    • GFLOPS(giga-FLOPs)等于十亿(10^9)次的浮点运算
    • TFLOPS(tera-FLOPs)等于一万亿(10^12)次的浮点运算
    • PFLOPS(peta-FLOPs)等于一千万亿(10^15)次的浮点运算
    • EFLOPS(exa-FLOPs)等于一百京(10^18)次的浮点运算
    • ZFLOPS(zetta-FLOPs)等于十万京(10^21)次的浮点运算
  2. MACs 系列:
    • MMACs(mega-MACs)等于一百万(10^6)次的乘和加浮点运算
    • GMACs(giga-FLOPs)等于十亿(10^9)次的乘和加浮点运算

import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertConfig
import transformers

# pip install fvcore
# https://github.com/facebookresearch/fvcore
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_flop_count.py
from fvcore.nn.flop_count import flop_count
# pip install thop
from thop import profile
# pip install ptflops
# https://github.com/sovrasov/flops-counter.pytorch
from ptflops import get_model_complexity_info


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(1024, 1024, bias=False)
        self.linear2 = nn.Linear(1024, 2048, bias=False)

    def forward(self, inputs):
        inputs = self.linear1(inputs)
        inputs = nn.Tanh()(inputs)
        inputs = self.linear2(inputs)
        inputs = nn.Tanh()(inputs)
        inputs = nn.Sigmoid()(inputs)
        return inputs


# 初始化模型
model = Net()

# 1. thop
def test01():
    inputs = torch.randn(size=[1, 1024])
    flops, params = profile(model=model, inputs=(inputs,))
    print('Net 模型计算量:', flops, 'flop')


# 2. ptflops
def test02():
    # 输出的是 MACs
    flops, params = get_model_complexity_info(model=model, input_res=(1, 1024), print_per_layer_stat=False, as_strings=False, flops_units='Flop')
    print('Net 模型计算量:', flops, 'flop')


# 3. fvcore
def test03():
    inputs = torch.randn(size=[1, 1024])
    flops, counter = flop_count(model=model, inputs=(inputs,))
    # 输出单位是 GFLOPs
    print('Net 模型计算量:', flops, 'GFLOPs', flops['linear']*1e9, 'flop')


if __name__ == '__main__':
    test01()
    print('-' * 30)
    test02()
    print('-' * 30)
    test03()

程序执行结果:

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Net 模型计算量: 3145728.0 flop
------------------------------
Net 模型计算量: 3145728.0 flop
------------------------------
Net 模型计算量: defaultdict(<class 'float'>, {'linear': 0.003145728}) GFLOPs 3145728.0 flop
Unsupported operator aten::tanh encountered 2 time(s)
Unsupported operator aten::sigmoid encountered 1 time(s)

未经允许不得转载:一亩三分地 » Bert 模型复杂度
评论 (0)

8 + 4 =