Knowledge Distillation from BERT to BiLSTM

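We use the TNEWS dataset, which is drawn from the news section of Toutiao (今日头条) and covers 15 news categories, including travel, education, finance, and military. Dataset size: training set (53,360), validation set (10,000), test set (10,000). An example record: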

{"label": "102", "label_des": "news_entertainment", "sentence": "江疏影甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物"}

Each record has three fields, in order: the category ID, the category name, and the news text (headline only).

Knowledge distillation paper: https://arxiv.org/pdf/1903.12136.pdf
Dataset link: https://github.com/CLUEbenchmark/CLUE

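The experiment covers the following steps:

  1. Fine-tune the bert-base-chinese model and evaluate it on the validation set;
  2. Train a bi-lstm model from scratch and evaluate it on the validation set;
  3. Have the student bi-lstm learn from the teacher bert model and evaluate it on the validation set;
  4. Submit test-set predictions to CLUE to obtain a model score.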

1. Preparing the training data

Dataset directory layout:

|-- README.txt
|-- dev.json
|-- labels.json
|-- test.json
|-- test1.0.json
`-- train.json

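Data preparation does just three things:

  1. Normalize each sentence
  2. Convert each label to a contiguous integer label starting at 0
  3. Store the result in datasets.Dataset format

Sample code: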

import json
import os
from collections import Counter

import pandas as pd
from datasets import Dataset, DatasetDict
from pyhanlp import JClass


def load_data():

    # Map each original label ID to a contiguous index starting at 0
    label_mapping = {}
    with open('tnews/labels.json', encoding='utf-8') as file:
        for index, line in enumerate(file):
            line = json.loads(line)
            label_mapping[line['label']] = {'label': index, 'label_desc': line['label_desc']}

    # Normalize the text with HanLP's CharTable
    normalizer = JClass('com.hankcs.hanlp.dictionary.other.CharTable')
    def clear_text(text):
        text = normalizer.convert(text)
        return text

    # Training set
    train_data = {'label': [], 'sentence': []}
    vocab = []
    with open('tnews/train.json', encoding='utf-8') as file:
        for line in file:
            line = json.loads(line)
            label = label_mapping[line['label']]['label']
            sentence = clear_text(line['sentence'])
            train_data['label'].append(label)
            train_data['sentence'].append(sentence)
            # Extending a list with a string adds its characters one by one
            vocab.extend(sentence)
    train_data = Dataset.from_dict(train_data)

    # Validation set
    valid_data = {'label': [], 'sentence': []}
    with open('tnews/dev.json', encoding='utf-8') as file:
        for line in file:
            line = json.loads(line)
            valid_data['label'].append(label_mapping[line['label']]['label'])
            valid_data['sentence'].append(clear_text(line['sentence']))
    valid_data = Dataset.from_dict(valid_data)

    # Build the vocabulary from characters that occur more than twice
    # (sorted so that token ids are the same on every run)
    word_freq = Counter(vocab)
    unique_vocab = sorted(word for word, freq in word_freq.items() if freq > 2)
    unique_vocab.insert(0, '[UNK]')
    unique_vocab.insert(0, '[PAD]')
    os.makedirs('data', exist_ok=True)
    with open('data/vocab.txt', 'w', encoding='utf-8') as file:
        file.write('\n'.join(unique_vocab))
    # Save the label mapping
    pd.DataFrame(label_mapping).transpose().to_csv('data/label.csv')
    # Save the dataset
    DatasetDict({'train': train_data, 'valid': valid_data}).save_to_disk('data')


if __name__ == '__main__':
    load_data()

After the script finishes, it produces the following files:

data
├── dataset_dict.json
├── label.csv
├── train
│   ├── dataset.arrow
│   ├── dataset_info.json
│   └── state.json
├── valid
│   ├── dataset.arrow
│   ├── dataset_info.json
│   └── state.json
└── vocab.txt
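
To make the contents of vocab.txt concrete, here is a minimal sketch of the vocabulary rule used by the script: characters are counted over the training sentences, those seen more than twice are kept, and [PAD] and [UNK] go first so that [PAD] gets id 0. The helper name `build_char_vocab` and the toy sentences are made up for illustration; they are not part of the original script.

```python
from collections import Counter


def build_char_vocab(sentences, min_freq=3):
    """Return [PAD], [UNK], then every character seen at least min_freq times."""
    counts = Counter(ch for sentence in sentences for ch in sentence)
    # Sort the kept characters so the id assignment is reproducible.
    kept = sorted(ch for ch, n in counts.items() if n >= min_freq)
    return ['[PAD]', '[UNK]'] + kept


# Toy stand-ins for news headlines (hypothetical data).
sentences = ['abca', 'abd', 'ab']
vocab = build_char_vocab(sentences)
token_to_id = {token: index for index, token in enumerate(vocab)}
print(vocab)
```

Here 'a' occurs four times and 'b' three times, so both are kept, while 'c' and 'd' occur once each and would be mapped to [UNK] at tokenization time.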

2. Fine-tuning the BERT model

import torch
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from datasets import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchmetrics import Accuracy
from torch.optim.lr_scheduler import LinearLR
from datasets import load_from_disk


# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_data():

    # Load the dataset
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):

        batch_inputs = []
        batch_labels = []

        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])

        batch_inputs = tokenizer(batch_inputs, padding=True, return_tensors='pt')
        batch_inputs = { key: value.to(device) for key, value in batch_inputs.items() }
        batch_labels = torch.tensor(batch_labels, device=device)

        return batch_inputs, batch_labels

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)

    return train_loader, valid_loader


def train_model():

    train_loader, valid_loader = load_data()
    estimator = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=15)
    estimator = estimator.to(device)
    optimizer = optim.AdamW(estimator.parameters(), lr=1e-5)
    # Linear warm-up: the learning rate grows from 0.3x to 1x over the first 5000 steps
    scheduler = LinearLR(optimizer=optimizer, start_factor=0.3, end_factor=1, total_iters=5000)
    num_epoch = 10
    # Recent torchmetrics releases require the task argument; older ones also accepted Accuracy()
    metric = Accuracy(task='multiclass', num_classes=15).to(device)

    for epoch_idx in range(num_epoch):

        # from_pretrained() returns the model in eval mode, so switch to train mode explicitly
        estimator.train()
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for batch_inputs, batch_labels in train_loader:
            outputs = estimator(**batch_inputs, labels=batch_labels)
            optimizer.zero_grad()
            outputs.loss.backward()
            optimizer.step()
            scheduler.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, outputs.loss.item()))
            progress.update()
            total_loss += (outputs.loss.item() * len(batch_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        estimator.eval()
        # Clear the metric state; otherwise torchmetrics keeps accumulating across epochs
        metric.reset()
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for batch_inputs, batch_labels in valid_loader:
                outputs = estimator(**batch_inputs)
                metric.update(outputs.logits, batch_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()

            acc = metric.compute().item()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model; three decimals, so names look like model1/5-0.565
        model_name = 'model1/{}-{:.3f}'.format(epoch_idx + 1, acc)
        estimator.save_pretrained(model_name)
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    train_model()

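After training, the saved checkpoints are:

1-0.559  2-0.562  3-0.564  4-0.565  5-0.565  6-0.564  7-0.563  8-0.563  9-0.563

The number before the hyphen is the epoch; the number after it is the validation accuracy logged for that checkpoint. The values shown come from a run in which the metric state was carried over from epoch to epoch, so each one is a running average over all validation passes up to that epoch rather than a single-epoch score.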
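
The fine-tuning script warms the learning rate up with `LinearLR(start_factor=0.3, end_factor=1, total_iters=5000)`. As a rough sketch of what that schedule does, the function below computes the learning rate at a given optimizer step using the closed form documented for PyTorch's LinearLR. The helper name `linear_warmup_lr` is hypothetical, and the defaults simply mirror the values in the script.

```python
def linear_warmup_lr(step, base_lr=1e-5, start_factor=0.3, end_factor=1.0, total_iters=5000):
    """Learning rate after `step` scheduler steps of a linear warm-up."""
    # The scaling factor moves linearly from start_factor to end_factor,
    # then stays at end_factor once total_iters steps have passed.
    progress = min(step, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


for step in (0, 2500, 5000, 10000):
    print(step, linear_warmup_lr(step))
```

With 53,360 training samples and a batch size of 16 there are 3,335 steps per epoch, so under these settings the warm-up finishes about halfway through the second epoch.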

3. Training the Bi-LSTM model

import torch
from transformers import BertTokenizer
import torch.nn as nn
from datasets import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR
from torchmetrics import Accuracy
from tqdm import tqdm
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from datasets import load_from_disk
import os


# Initialize the tokenizer from the character vocabulary built earlier
tokenizer = BertTokenizer(vocab_file='data/vocab.txt')
# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_data():

    # Load the dataset
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):

        batch_inputs = []
        batch_labels = []

        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])

        batch_inputs = tokenizer(batch_inputs,
                                 padding=False,
                                 return_token_type_ids=False,
                                 return_special_tokens_mask=False,
                                 add_special_tokens=False)['input_ids']
        # Sequence length of each sample in the batch
        batch_length = [len(inputs) for inputs in batch_inputs]
        # Indices that sort the batch by length, longest first
        sorted_index = np.argsort(-np.array(batch_length))
        # Reorder the inputs, labels, and lengths accordingly
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(torch.tensor(batch_inputs[index], device=device))
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])

        # Pad the sequences with 0 (the [PAD] id) via pad_sequence, producing a single tensor
        sorted_inputs = pad_sequence(sorted_inputs, batch_first=True)
        sorted_labels = torch.tensor(sorted_labels, device=device)

        return sorted_inputs, sorted_labels, sorted_length

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)

    return train_loader, valid_loader


class SequneceClassification(nn.Module):

    def __init__(self):
        super(SequneceClassification, self).__init__()
        self.embedds = nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=512)
        self.encoder = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=True)
        self.outputs = nn.Linear(in_features=512 * 2, out_features=15)

    def forward(self, inputs, lengths=None):

        inputs = self.embedds(inputs)
        if lengths is not None:
            inputs = pack_padded_sequence(inputs, lengths=lengths, batch_first=True)
        # inputs: the encoding of every token
        # cn: the cell state
        # hn: the hidden state at the last time step (one per direction)
        inputs, (hn, cn) = self.encoder(inputs)
        inputs = torch.tanh(hn)
        inputs = inputs.transpose(0, 1)
        inputs = inputs.reshape(inputs.shape[0], -1)
        inputs = self.outputs(inputs)

        return inputs


def train_model():

    train_loader, valid_loader = load_data()
    estimator = SequneceClassification()
    estimator = estimator.to(device)
    optimizer = optim.AdamW(estimator.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = LinearLR(optimizer=optimizer, start_factor=0.3, end_factor=1, total_iters=1000)
    num_epoch = 50
    # Recent torchmetrics releases require the task argument; older ones also accepted Accuracy()
    metric = Accuracy(task='multiclass', num_classes=15).to(device)

    for epoch_idx in range(num_epoch):

        estimator.train()
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for sorted_inputs, sorted_labels, sorted_length in train_loader:
            outputs = estimator(sorted_inputs, sorted_length)
            optimizer.zero_grad()
            # outputs is already (batch, 15); squeeze() would break on a batch of size 1
            loss = criterion(outputs, sorted_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, loss.item()))
            progress.update()
            total_loss += (loss.item() * len(sorted_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        estimator.eval()
        # Clear the metric state; otherwise torchmetrics keeps accumulating across epochs
        metric.reset()
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for sorted_inputs, sorted_labels, sorted_length in valid_loader:
                outputs = estimator(sorted_inputs, sorted_length)
                metric.update(outputs, sorted_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()

            acc = metric.compute().item()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model; makedirs also creates the parent model2 directory
        model_name = 'model2/{}-{:.4f}'.format(epoch_idx + 1, acc)
        os.makedirs(model_name, exist_ok=True)
        torch.save(estimator.state_dict(), model_name + '/bilstm.bin')
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    train_model()

After training, all of the saved checkpoints are:

10-0.5016  14-0.5011  19-0.5009  23-0.5008  28-0.5009  32-0.5009  37-0.5011  41-0.5010  46-0.5004  5-0.4994
1-0.4767   15-0.5012  20-0.5009  24-0.5007  29-0.5010  33-0.5009  38-0.5011  42-0.5009  47-0.5004  6-0.4997
11-0.5014  16-0.5010  2-0.4873   25-0.5007  30-0.5010  34-0.5009  39-0.5011  43-0.5008  48-0.5003  7-0.5013
12-0.5009  17-0.5006  21-0.5008  26-0.5007  3-0.4941   35-0.5010  40-0.5010  44-0.5006  49-0.5002  8-0.5019
13-0.5010  18-0.5008  22-0.5007  27-0.5008  31-0.5010  36-0.5011  4-0.4972   45-0.5004  50-0.5000  9-0.5014

Going by these logged values, the best checkpoint is the one from epoch 8, at 50.19% accuracy.
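
The collate function in this section sorts every batch by length before padding because `pack_padded_sequence` expects sequences in descending length order by default (`enforce_sorted=True`). The sketch below reproduces just that sort-and-pad step on plain Python lists, without tensors, so the reordering is easy to inspect. The helper name `sort_and_pad` and the toy ids are assumptions made for illustration.

```python
def sort_and_pad(batch_ids, batch_labels, pad_id=0):
    """Sort a batch longest-first and right-pad every sequence with pad_id."""
    # Indices of the samples, ordered from longest to shortest sequence.
    order = sorted(range(len(batch_ids)), key=lambda i: len(batch_ids[i]), reverse=True)
    lengths = [len(batch_ids[i]) for i in order]
    max_len = lengths[0] if lengths else 0
    padded = [batch_ids[i] + [pad_id] * (max_len - len(batch_ids[i])) for i in order]
    # Labels must follow the same permutation as the inputs.
    labels = [batch_labels[i] for i in order]
    return padded, labels, lengths


ids = [[5, 6], [7, 8, 9, 10], [11]]   # toy token ids
labels = [0, 1, 2]
padded, sorted_labels, lengths = sort_and_pad(ids, labels)
print(padded, sorted_labels, lengths)
```

An alternative is to pass `enforce_sorted=False` to `pack_padded_sequence` and skip the manual sort; the explicit version above is shown only because it mirrors what the script does.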

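4. Distilling BERT's knowledge into the BiLSTM

This part of the experiment is very simple. The steps are:

  1. Run BERT over the training set and store its output logits;
  2. Train the BiLSTM with an MSE loss between its own logits on the training set and the stored BERT logits;
  3. Report the BiLSTM's accuracy on the validation set during training.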
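
The objective in step 2 can be sketched in plain Python: the student is penalized by the mean squared difference between its logits and the teacher's, averaged over every logit in the batch (the same reduction as the default `nn.MSELoss`). The function name `mse_logit_loss` and the toy numbers are illustrative assumptions, not taken from the script.

```python
def mse_logit_loss(student_logits, teacher_logits):
    """Mean squared error over every logit in the batch."""
    total, count = 0.0, 0
    for student_row, teacher_row in zip(student_logits, teacher_logits):
        for s, t in zip(student_row, teacher_row):
            total += (s - t) ** 2
            count += 1
    return total / count


# Two samples with two classes each (toy values).
student = [[1.0, 2.0], [0.0, 0.0]]
teacher = [[1.0, 0.0], [1.0, 0.0]]
loss = mse_logit_loss(student, teacher)
print(loss)
```

Because the targets are the teacher's raw logits rather than hard labels, the student is also told how the teacher ranks the wrong classes. As far as I recall, the paper additionally describes mixing this term with the usual cross-entropy on the true labels; the experiment here uses the MSE term alone.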

Sample code:

import torch
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_from_disk
import torch.optim as optim
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torchmetrics import Accuracy
import os

# Select the compute device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_teacher():

    # Initialize the tokenizer saved with the fine-tuned teacher
    tokenizer = BertTokenizer.from_pretrained('model1/5-0.565')

    # Load the dataset
    data = load_from_disk('data')
    train_data, valid_data = data['train'], data['valid']

    def collate_fn(batch_data):

        batch_inputs = [data['sentence'] for data in batch_data]
        batch_inputs = tokenizer(batch_inputs, padding=True, return_tensors='pt')
        batch_inputs = { key: value.to(device) for key, value in batch_inputs.items() }

        return batch_inputs, batch_data

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=False)

    return train_loader


def teacher_do():

    train_loader = load_teacher()
    estimator = BertForSequenceClassification.from_pretrained('model1/5-0.565', num_labels=15)
    estimator = estimator.to(device)
    estimator.eval()

    progress = tqdm(range(len(train_loader)))
    teacher_knowledge = {'label': [], 'sentence': []}
    for batch_inputs, batch_data in train_loader:
        with torch.no_grad():
            outputs = estimator(**batch_inputs).logits
        for logit, sample in zip(outputs.cpu().numpy().tolist(), batch_data):
            teacher_knowledge['label'].append(logit)
            teacher_knowledge['sentence'].append(sample['sentence'])
        progress.update()

    progress.close()
    # Save the teacher's knowledge (one logit vector per training sentence)
    Dataset.from_dict(teacher_knowledge).save_to_disk('knowledge')


class SequneceClassification(nn.Module):

    def __init__(self):
        super(SequneceClassification, self).__init__()
        self.embedds = nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=512)
        self.encoder = nn.LSTM(input_size=512, hidden_size=512, batch_first=True, bidirectional=True)
        self.outputs = nn.Linear(in_features=512 * 2, out_features=15)

    def forward(self, inputs, lengths=None):

        inputs = self.embedds(inputs)
        if lengths is not None:
            inputs = pack_padded_sequence(inputs, lengths=lengths, batch_first=True)
        # inputs: the encoding of every token
        # cn: the cell state
        # hn: the hidden state at the last time step (one per direction)
        inputs, (hn, cn) = self.encoder(inputs)
        inputs = torch.tanh(hn)
        inputs = inputs.transpose(0, 1)
        inputs = inputs.reshape(inputs.shape[0], -1)
        inputs = self.outputs(inputs)

        return inputs


# Initialize the student's tokenizer from the character vocabulary
tokenizer = BertTokenizer(vocab_file='data/vocab.txt')

def load_student():

    # Load the validation set and the stored teacher logits
    valid_data = load_from_disk('data')['valid']
    train_data = load_from_disk('knowledge')

    def collate_fn(batch_data):

        batch_inputs = []
        batch_labels = []

        for data in batch_data:
            batch_inputs.append(data['sentence'])
            batch_labels.append(data['label'])

        batch_inputs = tokenizer(batch_inputs,
                                 padding=False,
                                 return_token_type_ids=False,
                                 return_special_tokens_mask=False,
                                 add_special_tokens=False)['input_ids']
        # Sequence length of each sample in the batch
        batch_length = [len(inputs) for inputs in batch_inputs]
        # Indices that sort the batch by length, longest first
        sorted_index = np.argsort(-np.array(batch_length))
        # Reorder the inputs, labels, and lengths accordingly
        sorted_inputs, sorted_labels, sorted_length = [], [], []
        for index in sorted_index:
            sorted_inputs.append(torch.tensor(batch_inputs[index], device=device))
            sorted_labels.append(batch_labels[index])
            sorted_length.append(batch_length[index])

        # Pad the sequences with 0 (the [PAD] id) via pad_sequence, producing a single tensor
        sorted_inputs = pad_sequence(sorted_inputs, batch_first=True)
        sorted_labels = torch.tensor(sorted_labels, device=device)

        return sorted_inputs, sorted_labels, sorted_length

    params = {'batch_size': 16, 'collate_fn': collate_fn}
    train_loader = DataLoader(train_data, **params, shuffle=True)
    valid_loader = DataLoader(valid_data, **params, shuffle=False)

    return train_loader, valid_loader


def student_to():

    train_loader, valid_loader = load_student()
    estimator = SequneceClassification().to(device)
    optimizer = optim.Adam(estimator.parameters(), lr=1e-4)
    # The training targets are the teacher's logits, so regress onto them with MSE
    criterion = nn.MSELoss()
    # Recent torchmetrics releases require the task argument; older ones also accepted Accuracy()
    metric = Accuracy(task='multiclass', num_classes=15).to(device)
    num_epoch = 20

    for epoch_idx in range(num_epoch):

        estimator.train()
        progress = tqdm(range(len(train_loader)))
        total_loss = 0.0
        for sorted_inputs, sorted_labels, sorted_length in train_loader:
            outputs = estimator(sorted_inputs, sorted_length)
            optimizer.zero_grad()
            # outputs and sorted_labels are both (batch, 15) here
            loss = criterion(outputs, sorted_labels)
            loss.backward()
            optimizer.step()
            progress.set_description('epoch %2d iter loss %.2f' % (epoch_idx + 1, loss.item()))
            progress.update()
            total_loss += (loss.item() * len(sorted_labels))
        progress.set_description('epoch %2d total loss %.2f' % (epoch_idx + 1, total_loss))
        progress.close()

        estimator.eval()
        # Clear the metric state; otherwise torchmetrics keeps accumulating across epochs
        metric.reset()
        with torch.no_grad():
            progress = tqdm(range(len(valid_loader)))
            for sorted_inputs, sorted_labels, sorted_length in valid_loader:
                outputs = estimator(sorted_inputs, sorted_length)
                metric.update(outputs, sorted_labels)
                progress.set_description('epoch %2d valid acc 0.000' % (epoch_idx + 1))
                progress.update()

            acc = metric.compute().item()
            progress.set_description('epoch %2d valid acc %.3f' % (epoch_idx + 1, acc))
            progress.close()

        # Save the model; makedirs also creates the parent model3 directory
        model_name = 'model3/{}-{:.4f}'.format(epoch_idx + 1, acc)
        os.makedirs(model_name, exist_ok=True)
        torch.save(estimator.state_dict(), model_name + '/bilstm.bin')
        tokenizer.save_pretrained(model_name)


if __name__ == '__main__':
    # Generate the teacher logits once; student_to() reads them from 'knowledge'
    if not os.path.exists('knowledge'):
        teacher_do()
    student_to()

Output of the training run:

epoch  1 total loss 197160.26: 100%|████████| 3335/3335 [00:44<00:00, 75.35it/s]
epoch  1 valid acc 0.477: 100%|██████████████| 625/625 [00:04<00:00, 133.95it/s]
epoch  2 total loss 147970.64: 100%|████████| 3335/3335 [00:43<00:00, 75.85it/s]
epoch  2 valid acc 0.494: 100%|██████████████| 625/625 [00:04<00:00, 134.77it/s]
epoch  3 total loss 129521.85: 100%|████████| 3335/3335 [00:44<00:00, 75.79it/s]
epoch  3 valid acc 0.501: 100%|██████████████| 625/625 [00:04<00:00, 134.88it/s]
epoch  4 total loss 114689.55: 100%|████████| 3335/3335 [00:44<00:00, 75.78it/s]
epoch  4 valid acc 0.507: 100%|██████████████| 625/625 [00:04<00:00, 134.75it/s]
epoch  5 total loss 100872.29: 100%|████████| 3335/3335 [00:43<00:00, 75.94it/s]
epoch  5 valid acc 0.510: 100%|██████████████| 625/625 [00:04<00:00, 134.70it/s]
epoch  6 total loss 87712.62: 100%|█████████| 3335/3335 [00:43<00:00, 75.87it/s]
epoch  6 valid acc 0.513: 100%|██████████████| 625/625 [00:04<00:00, 135.19it/s]
epoch  7 total loss 75687.72: 100%|█████████| 3335/3335 [00:44<00:00, 75.77it/s]
epoch  7 valid acc 0.515: 100%|██████████████| 625/625 [00:04<00:00, 134.78it/s]
epoch  8 total loss 64895.03: 100%|█████████| 3335/3335 [00:44<00:00, 75.76it/s]
epoch  8 valid acc 0.517: 100%|██████████████| 625/625 [00:04<00:00, 135.04it/s]
epoch  9 total loss 55662.15: 100%|█████████| 3335/3335 [00:44<00:00, 75.67it/s]
epoch  9 valid acc 0.518: 100%|██████████████| 625/625 [00:04<00:00, 134.84it/s]
epoch 10 total loss 47618.56: 100%|█████████| 3335/3335 [00:43<00:00, 75.80it/s]
epoch 10 valid acc 0.519: 100%|██████████████| 625/625 [00:04<00:00, 134.93it/s]
epoch 11 total loss 40930.74: 100%|█████████| 3335/3335 [00:44<00:00, 75.69it/s]
epoch 11 valid acc 0.519: 100%|██████████████| 625/625 [00:04<00:00, 134.59it/s]
epoch 12 total loss 35400.10: 100%|█████████| 3335/3335 [00:44<00:00, 75.66it/s]
epoch 12 valid acc 0.520: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 13 total loss 30833.56: 100%|█████████| 3335/3335 [00:44<00:00, 75.73it/s]
epoch 13 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 134.48it/s]
epoch 14 total loss 26831.78: 100%|█████████| 3335/3335 [00:44<00:00, 75.71it/s]
epoch 14 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 134.95it/s]
epoch 15 total loss 23764.05: 100%|█████████| 3335/3335 [00:44<00:00, 75.78it/s]
epoch 15 valid acc 0.521: 100%|██████████████| 625/625 [00:04<00:00, 135.02it/s]
epoch 16 total loss 20811.81: 100%|█████████| 3335/3335 [00:44<00:00, 75.77it/s]
epoch 16 valid acc 0.522: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 17 total loss 18524.72: 100%|█████████| 3335/3335 [00:44<00:00, 75.71it/s]
epoch 17 valid acc 0.522: 100%|██████████████| 625/625 [00:04<00:00, 134.95it/s]
epoch 18 total loss 16605.46: 100%|█████████| 3335/3335 [00:44<00:00, 75.69it/s]
epoch 18 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.85it/s]
epoch 19 total loss 14851.84: 100%|█████████| 3335/3335 [00:44<00:00, 75.72it/s]
epoch 19 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.75it/s]
epoch 20 total loss 13164.29: 100%|█████████| 3335/3335 [00:44<00:00, 75.75it/s]
epoch 20 valid acc 0.523: 100%|██████████████| 625/625 [00:04<00:00, 134.79it/s]

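On the validation set, the BiLSTM's accuracy rises from 50.19% when trained directly to 52.3% with distillation, which is a clear improvement. I expect that tuning the hyperparameters or training for more epochs would raise the BiLSTM's accuracy further.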

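Here the BiLSTM is distilled on the original training set only. The paper also notes that augmenting the original data lets the student model learn more dark knowledge from the teacher model, and it describes the augmentation methods.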
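
As a rough illustration of what such augmentation might look like, the sketch below implements simplified versions of two rules of the kind the paper describes: random masking and n-gram sampling. It is not the paper's exact procedure. The POS-guided word replacement rule is omitted because it needs a part-of-speech tagger, and the function name `augment` and its default probabilities are assumptions chosen for illustration.

```python
import random


def augment(tokens, p_mask=0.1, p_ng=0.25, max_n=5, rng=random):
    """Return one synthetic variant of a token list (simplified sketch)."""
    # Masking: each token is independently replaced by [MASK] with probability p_mask.
    out = ['[MASK]' if rng.random() < p_mask else tok for tok in tokens]
    # n-gram sampling: with probability p_ng, keep only a random contiguous
    # span of 1..max_n tokens.
    if out and rng.random() < p_ng:
        n = rng.randint(1, min(max_n, len(out)))
        start = rng.randint(0, len(out) - n)
        out = out[start:start + n]
    return out


rng = random.Random(0)
tokens = list('abcdefgh')   # stand-in for a headline split into characters
for _ in range(3):
    print(augment(tokens, rng=rng))
```

Each synthetic sentence would then be passed through the teacher to obtain logits, giving the student additional unlabeled examples to regress onto.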
