Prototypical Networks

Prototypical Networks are an approach to few-shot learning. The core idea is to map samples into a low-dimensional embedding space, compute a prototype for each class in that space, and then classify a test sample by measuring its distance to each class prototype.
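
In the notation of the paper, the prototype c_k of class k is the mean of the embedded support samples S_k, and a query sample x is classified by a softmax over negative distances:

c_k = (1 / |S_k|) * Σ_{(x_i, y_i) ∈ S_k} f_φ(x_i)

p(y = k | x) = softmax(-d(f_φ(x), c_k))

Here f_φ is the encoder and d is a distance function (the paper uses squared Euclidean distance).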

Paper: https://arxiv.org/pdf/1703.05175.pdf

Training procedure:

  1. Build the inputs: randomly construct two datasets from the small labeled dataset; call one the support set and the other the query set. For example, if the dataset contains 5 classes, sample 10 examples per class to form a 50-sample support set, then sample 15 examples per class from the remaining data as the query set;
  2. Class prototypes: feed the 50 support samples through the encoder to get an embedding for each sample, then average the embeddings within each class to obtain that class's prototype;
  3. Compute distances: encode the query samples and compute the distance between each query embedding and each prototype (e.g., Euclidean distance or cosine similarity);
  4. Compute the loss: convert the distances into a probability distribution and compute the cross-entropy loss for this episode (see the sketch after this list);
  5. Repeat steps 1-4 until the model converges.
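
Step 4 can be written in a few lines. The following is a minimal sketch (the tensor names `query_embeddings`, `prototypes`, and `query_labels` are illustrative), not the full training code shown later:

import torch
import torch.nn.functional as F

def episode_loss(query_embeddings, prototypes, query_labels):
    # Pairwise Euclidean distances, shape: (num_query, num_classes)
    distances = torch.cdist(query_embeddings, prototypes)
    # Negative distances act as logits: the closer a prototype,
    # the higher the probability after the softmax inside cross_entropy
    return F.cross_entropy(-distances, query_labels)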

After training is finished, the mean embedding of all training samples of each class is used as that class's prototype.

Prediction procedure:

  1. Encode the input sample;
  2. Compute the distance between the sample and each prototype, and convert the distances into a probability distribution;
  3. Assign the sample to the class with the highest probability (a minimal sketch follows this list).
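
Prediction therefore reduces to a nearest-prototype lookup. A minimal sketch, assuming a single query `embedding` of shape (dim,) and `prototypes` of shape (num_classes, dim), both names illustrative:

import torch
import torch.nn.functional as F

def predict(embedding, prototypes):
    # Distance from the query embedding to every class prototype
    distances = torch.cdist(embedding.unsqueeze(0), prototypes).squeeze(0)
    # Smaller distance -> higher probability; pick the most likely class
    probs = F.softmax(-distances, dim=-1)
    return torch.argmax(probs).item()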

Example code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from datasets import load_from_disk
from transformers import AlbertModel, BertTokenizer
from sklearn.metrics import accuracy_score
import datasets
datasets.disable_progress_bar()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class PrototypicalNetworks(nn.Module):

    def __init__(self):
        super(PrototypicalNetworks, self).__init__()
        self.encoder = AlbertModel.from_pretrained('albert_chinese_tiny')
        self.prototype = None  # filled in by build_prototype after training


    def get_prototype(self, support_title, support_label):
        # Encode the support set; use the [CLS] vector as the sentence embedding
        support_outputs = self.encoder(**support_title)
        sentence_embeddings = support_outputs.last_hidden_state[:, 0]

        # Average the embeddings within each class to obtain its prototype
        support_embeddings = []
        for label in range(2):
            class_embedding = sentence_embeddings[support_label == label]
            support_embeddings.append(torch.mean(class_embedding, dim=0))
        support_embeddings = torch.stack(support_embeddings)
        # support_embeddings = F.normalize(support_embeddings, p=2, dim=-1)

        return support_embeddings


    @torch.no_grad()
    def build_prototype(self, trainset, tokenizer):
        self.eval()
        embeddings = []
        labels = []
        # map is used here purely for its side effects: it batches the whole
        # dataset through the encoder and collects embeddings and labels
        def collate_function(batch_data):
            title = batch_data['title']
            label = batch_data['label']
            title = tokenizer(title, padding='longest', return_token_type_ids=False, return_tensors='pt')
            title = {key: value.to(device) for key, value in title.items()}

            outputs = self.encoder(**title)

            embeddings.extend(outputs.last_hidden_state[:, 0])

            labels.extend(label)

        trainset.map(collate_function, batched=True, batch_size=32)

        embeddings = torch.stack(embeddings)

        # The mean embedding of each class becomes its stored prototype
        labels = torch.tensor(labels)
        prototype_0 = torch.mean(embeddings[labels == 0], dim=0)
        prototype_1 = torch.mean(embeddings[labels == 1], dim=0)
        self.prototype = torch.stack([prototype_0, prototype_1])


    @torch.no_grad()
    def get_scores(self, trainset, tokenizer):
        self.eval()

        embeddings = []
        def collate_function(batch_data):
            title = batch_data['title']
            title = tokenizer(title, padding='longest', return_token_type_ids=False, return_tensors='pt')
            title = {key: value.to(device) for key, value in title.items()}

            outputs = self.encoder(**title)

            embeddings.extend(outputs.last_hidden_state[:, 0])

        trainset.map(collate_function, batched=True, batch_size=16)
        embeddings = torch.stack(embeddings)

        # Classify each sample as the class of its nearest prototype
        y_preds = []
        for embedding in embeddings:
            distance = F.pairwise_distance(embedding.expand(self.prototype.size()), self.prototype)
            y_pred = torch.argmin(distance)
            y_preds.append(y_pred.item())

        return y_preds


    def get_query(self, query_title):
        query_outputs = self.encoder(**query_title)
        query_embeddings = query_outputs.last_hidden_state[:, 0]
        # query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
        return query_embeddings

    def get_loss(self, prototype, query, query_label):
        # Distance from each query embedding to each class prototype
        distances = []
        for q in query:
            distance = F.pairwise_distance(q.expand(prototype.size()), prototype)
            distances.append(distance)
        distances = torch.stack(distances)
        # Negative distances act as logits, so the closest prototype gets
        # the highest probability under the softmax inside cross_entropy
        loss = F.cross_entropy(-distances, query_label)

        return loss


    def forward(self, support_title, support_label, query_title, query_label):
        self.train()
        prototype = self.get_prototype(support_title, support_label)
        query = self.get_query(query_title)
        loss = self.get_loss(prototype, query, query_label)

        return loss


def generate_inputs(traindata, tokenizer):

    # Collect the sample indices of each of the two classes
    class_0_index = torch.argwhere(torch.tensor(traindata['label']) == 0).squeeze().tolist()
    class_1_index = torch.argwhere(torch.tensor(traindata['label']) == 1).squeeze().tolist()

    # Sample 10 examples per class (without replacement) as the support set
    class_0_support = np.random.choice(class_0_index, 10, replace=False).tolist()
    class_1_support = np.random.choice(class_1_index, 10, replace=False).tolist()
    support_index = class_0_support + class_1_support
    support = traindata[support_index]
    support_title = support['title']
    support_label = support['label']
    support_title = tokenizer(support_title, padding='longest', return_token_type_ids=False, return_tensors='pt')
    support_title = {key: value.to(device) for key, value in support_title.items()}
    support_label = torch.tensor(support_label, device=device)

    # Sample the query set from examples not used in the support set
    class_0_query = np.random.choice(np.setdiff1d(class_0_index, class_0_support), 10, replace=False).tolist()
    class_1_query = np.random.choice(np.setdiff1d(class_1_index, class_1_support), 10, replace=False).tolist()
    query_index = class_0_query + class_1_query
    query = traindata[query_index]
    query_title = query['title']
    query_label = query['label']
    query_title = tokenizer(query_title, padding='longest', return_token_type_ids=False, return_tensors='pt')
    query_title = {key: value.to(device) for key, value in query_title.items()}
    query_label = torch.tensor(query_label, device=device)

    return support_title, support_label, query_title, query_label


def train():

    estimator = PrototypicalNetworks().to(device)
    tokenizer = BertTokenizer.from_pretrained('albert_chinese_tiny')
    optimizer = optim.Adam(estimator.parameters(), lr=1e-3)
    # The small 'test' split serves as the few-shot training data;
    # the larger 'train' split is held out for evaluation
    traindata = load_from_disk('data/sentence.data')['test']
    evaldata = load_from_disk('data/sentence.data')['train']

    for index in range(1000):
        # One episode: sample a fresh support/query split and update the encoder
        support_title, support_label, query_title, query_label = generate_inputs(traindata, tokenizer)
        loss = estimator(support_title, support_label, query_title, query_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print('loss:', loss.item())

        # Rebuild the prototypes from all training samples, then evaluate
        estimator.build_prototype(traindata, tokenizer)
        y_preds = estimator.get_scores(evaldata, tokenizer)
        print('accuracy:', accuracy_score(evaldata['label'], y_preds))


if __name__ == '__main__':
    train()
