Prototypical networks are a method for few-shot learning. The core idea is to map samples into an embedding space, compute a prototype for each class in that space, and then classify a test sample by measuring its distance to each prototype.
Paper: https://arxiv.org/pdf/1703.05175.pdf
Training process:
- Construct the episode: randomly build two datasets from the small labeled dataset; call one the support set and the other the query set. For example, if the dataset contains 5 classes, sample 10 examples from each class to form a support set of 50 samples, then sample 15 examples per class from the remaining data of those 5 classes to form the query set.
- Class prototypes: feed the 50 support samples of the 5 classes through the encoder to obtain an embedding for each sample, then average the embeddings within each class to obtain that class's prototype.
- Compute distances: feed the query samples through the encoder to obtain their embeddings, then compute the distance between each query embedding and each prototype (e.g., Euclidean distance or cosine similarity).
- Compute the loss: convert the distances into a probability distribution over classes and use cross-entropy to compute the loss for this episode (a minimal sketch of one episode follows this list).
- Repeat the steps above until the model converges.
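The math inside one episode is compact. Below is a minimal PyTorch sketch assuming the support and query samples have already been passed through the encoder; all tensor names and shapes are illustrative and not taken from the full example at the end of this post:

```python
import torch
import torch.nn.functional as F

# Illustrative episode: 5 classes, 10 support and 15 query samples per class, 64-dim embeddings.
n_classes, emb_dim = 5, 64
support_emb = torch.randn(n_classes, 10, emb_dim)            # encoder output for the support set
query_emb = torch.randn(n_classes * 15, emb_dim)             # encoder output for the query set
query_label = torch.arange(n_classes).repeat_interleave(15)  # true class of each query sample

# Class prototype = mean embedding of that class's support samples.
prototypes = support_emb.mean(dim=1)                         # (n_classes, emb_dim)

# Euclidean distance from every query embedding to every prototype.
distances = torch.cdist(query_emb, prototypes)               # (n_queries, n_classes)

# Smaller distance should mean higher probability, so negate the distances
# before the softmax that cross_entropy applies internally.
loss = F.cross_entropy(-distances, query_label)
```

Since `cross_entropy` applies the softmax internally, the negated distances act as logits, which is exactly the distance-to-probability conversion described above.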
Once training is finished, the average embedding of all training samples in each class is used as that class's final prototype.
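A minimal sketch of this final prototype construction, with illustrative tensors standing in for the real encoder outputs:

```python
import torch

# Illustrative: embeddings and labels for the full training set (5 classes, 64-dim).
all_emb = torch.randn(500, 64)                 # encoder output for every training sample
all_label = torch.randint(0, 5, (500,))        # class label of each sample

# Final prototype of each class = mean embedding over all of its training samples.
prototypes = torch.stack([all_emb[all_label == c].mean(dim=0) for c in range(5)])  # (5, 64)
```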
Prediction process:
- Encode the input sample.
- Compute the distance between the sample and each prototype, and convert the distances into a probability distribution.
- Assign the sample to the class with the highest probability (a minimal sketch follows this list).
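Prediction is therefore a nearest-prototype lookup; a minimal sketch with illustrative tensors:

```python
import torch
import torch.nn.functional as F

# Illustrative: 5 class prototypes and one encoded input sample, 64-dim embeddings.
prototypes = torch.randn(5, 64)
sample_emb = torch.randn(1, 64)                              # encoder output for the input sample

distances = torch.cdist(sample_emb, prototypes).squeeze(0)   # (5,) distance to each prototype
probs = F.softmax(-distances, dim=0)                         # smaller distance -> higher probability
pred = torch.argmax(probs).item()                            # index of the predicted class
```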
Example code (binary text classification with an ALBERT encoder):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import datasets
from datasets import load_from_disk
from transformers import AlbertModel, BertTokenizer

datasets.disable_progress_bar()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class PrototypicalNetworks(nn.Module):

    def __init__(self):
        super(PrototypicalNetworks, self).__init__()
        self.encoder = AlbertModel.from_pretrained('albert_chinese_tiny')
        self.prototype = None

    def get_prototype(self, support_title, support_label):
        # Encode the support set and take the [CLS] vector as the sentence embedding.
        support_outputs = self.encoder(**support_title)
        sentence_embeddings = support_outputs.last_hidden_state[:, 0]
        support_embeddings = []
        for label in range(2):
            # Prototype of each class = mean embedding of its support samples.
            class_embedding = sentence_embeddings[support_label == label]
            support_embeddings.append(torch.mean(class_embedding, dim=0))
        support_embeddings = torch.stack(support_embeddings)
        # support_embeddings = F.normalize(support_embeddings, p=2, dim=-1)
        return support_embeddings

    @torch.no_grad()
    def build_prototype(self, trainset, tokenizer):
        # After training, rebuild each class prototype from the whole training set.
        self.eval()
        embeddings = []
        labels = []

        def collate_function(batch_data):
            title = batch_data['title']
            label = batch_data['label']
            title = tokenizer(title, padding='longest', return_token_type_ids=False, return_tensors='pt')
            title = {key: value.to(device) for key, value in title.items()}
            outputs = self.encoder(**title)
            embeddings.extend(outputs.last_hidden_state[:, 0])
            labels.extend(label)

        # map() is used only for its side effect of collecting embeddings and labels.
        trainset.map(collate_function, batched=True, batch_size=32)
        embeddings = torch.stack(embeddings)
        a = torch.mean(embeddings[torch.tensor(labels) == 0], dim=0)
        b = torch.mean(embeddings[torch.tensor(labels) == 1], dim=0)
        self.prototype = torch.stack([a, b])

    @torch.no_grad()
    def get_scores(self, dataset, tokenizer):
        self.eval()
        embeddings = []

        def collate_function(batch_data):
            title = batch_data['title']
            title = tokenizer(title, padding='longest', return_token_type_ids=False, return_tensors='pt')
            title = {key: value.to(device) for key, value in title.items()}
            outputs = self.encoder(**title)
            embeddings.extend(outputs.last_hidden_state[:, 0])

        dataset.map(collate_function, batched=True, batch_size=16)
        embeddings = torch.stack(embeddings)
        y_preds = []
        for embedding in embeddings:
            distance = F.pairwise_distance(embedding.expand(self.prototype.size()), self.prototype)
            # The closest prototype (smallest distance) gives the predicted class.
            y_pred = torch.argmin(distance)
            y_preds.append(y_pred.item())
        return y_preds

    def get_query(self, query_title):
        query_outputs = self.encoder(**query_title)
        query_embeddings = query_outputs.last_hidden_state[:, 0]
        # query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
        return query_embeddings

    def get_loss(self, prototype, query, query_label):
        distances = []
        for q in query:
            distance = F.pairwise_distance(q.expand(prototype.size()), prototype)
            distances.append(distance)
        distances = torch.stack(distances)
        # Negated distances act as logits: smaller distance -> higher probability.
        loss = F.cross_entropy(-distances, query_label)
        return loss

    def forward(self, support_title, support_label, query_title, query_label):
        self.train()
        prototype = self.get_prototype(support_title, support_label)
        query = self.get_query(query_title)
        loss = self.get_loss(prototype, query, query_label)
        return loss


def generate_inputs(traindata, tokenizer):
    # Collect the sample indices of the two classes.
    class_0_index = torch.argwhere(torch.tensor(traindata['label']) == 0).squeeze().tolist()
    class_1_index = torch.argwhere(torch.tensor(traindata['label']) == 1).squeeze().tolist()
    # Sample 10 examples per class as the support set.
    class_0_support = np.random.choice(class_0_index, 10).tolist()
    class_1_support = np.random.choice(class_1_index, 10).tolist()
    support_index = class_0_support + class_1_support
    support = traindata[support_index]
    support_title = support['title']
    support_label = support['label']
    support_title = tokenizer(support_title, padding='longest', return_token_type_ids=False, return_tensors='pt')
    support_title = {key: value.to(device) for key, value in support_title.items()}
    support_label = torch.tensor(support_label, device=device)
    # Sample 10 examples per class from the remaining data as the query set.
    class_0_query = np.random.choice(np.setdiff1d(class_0_index, class_0_support), 10).tolist()
    class_1_query = np.random.choice(np.setdiff1d(class_1_index, class_1_support), 10).tolist()
    query_index = class_0_query + class_1_query
    query = traindata[query_index]
    query_title = query['title']
    query_label = query['label']
    query_title = tokenizer(query_title, padding='longest', return_token_type_ids=False, return_tensors='pt')
    query_title = {key: value.to(device) for key, value in query_title.items()}
    query_label = torch.tensor(query_label, device=device)
    return support_title, support_label, query_title, query_label


def train():
    estimator = PrototypicalNetworks().to(device)
    tokenizer = BertTokenizer.from_pretrained('albert_chinese_tiny')
    optimizer = optim.Adam(estimator.parameters(), lr=1e-3)
    traindata = load_from_disk('data/sentence.data')['test']
    evaldata = load_from_disk('data/sentence.data')['train']
    for index in range(1000):
        support_title, support_label, query_title, query_label = generate_inputs(traindata, tokenizer)
        loss = estimator(support_title, support_label, query_title, query_label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss.item())
    # Build the final prototypes and evaluate on the held-out split.
    estimator.build_prototype(traindata, tokenizer)
    y_preds = estimator.get_scores(evaldata, tokenizer)
    from sklearn.metrics import accuracy_score
    print('Accuracy:', accuracy_score(evaldata['label'], y_preds))


if __name__ == '__main__':
    train()
```