This code implements an RNN-based sentiment analysis model. Its main features are:

- Model definition: an `Embedding` layer, an `RNN` layer, and a `Linear` layer for text classification.
- Input handling: `pack_padded_sequence` packs the variable-length sequences so that PAD positions do not affect the computation (see the short sketch after this list).
- Model saving and loading: the model can be serialized and restored with `pickle`.
- Test example: the texts are encoded, sorted by length, fed to the model, and the classification scores are printed.
Note: in the code below, `nn.RNN` can be swapped directly for `nn.GRU` or `nn.LSTM`, which are easier to train.
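As a rough sketch of that swap (the `build_rnn` helper is hypothetical and only used here for comparison): `nn.GRU` has the same return values as `nn.RNN`, while `nn.LSTM` additionally returns the cell state, so one line of the forward pass changes:

```python
import torch
import torch.nn as nn

# Hypothetical helper for this sketch: build the recurrent layer from a name.
def build_rnn(kind='rnn'):
    layer_cls = {'rnn': nn.RNN, 'gru': nn.GRU, 'lstm': nn.LSTM}[kind]
    return layer_cls(input_size=128, hidden_size=256, batch_first=True)

x = torch.randn(2, 5, 128)                 # (batch, seq_len, embedding_dim)

output, hn = build_rnn('gru')(x)           # GRU: same return values as nn.RNN
print(hn.shape)                            # torch.Size([1, 2, 256])

output, (hn, cn) = build_rnn('lstm')(x)    # LSTM also returns the cell state cn
print(hn.shape, cn.shape)                  # torch.Size([1, 2, 256]) torch.Size([1, 2, 256])
```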
Create an `estimator.py` file and add the following code:
```python
import pickle

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence

from tokenizer import Tokenzier


class SentimentAnalysis(nn.Module):

    def __init__(self, vocab_size=0, num_labels=2, padding_idx=0):
        super(SentimentAnalysis, self).__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.padding_idx = padding_idx
        self.ebd = nn.Embedding(num_embeddings=vocab_size, embedding_dim=128, padding_idx=padding_idx)
        self.rnn = nn.RNN(input_size=128, hidden_size=256, batch_first=True)
        self.out = nn.Linear(in_features=256, out_features=num_labels)

    def forward(self, input_ids, batch_length):
        inputs = self.ebd(input_ids)
        # Convert the padded batch into a PackedSequence so that PAD positions
        # do not take part in the RNN computation
        inputs = pack_padded_sequence(inputs, lengths=batch_length, batch_first=True, enforce_sorted=True)
        output, hn = self.rnn(inputs)
        # Convert the PackedSequence back to a padded batch if per-step outputs are needed
        # output, lens = pad_packed_sequence(output, batch_first=True, padding_value=self.ebd.padding_idx)
        # logging.info(f'output: {output.shape}, lens: {lens}, hn: {hn.shape}')
        # hn has shape (1, batch_size, 256); squeeze drops the leading layer dimension
        inputs = self.out(hn.squeeze())
        return inputs

    def save(self, path):
        init_param = {'vocab_size': self.vocab_size,
                      'num_labels': self.num_labels,
                      'padding_idx': self.padding_idx}
        parameters = self.state_dict()
        save_data = {'init_param': init_param, 'parameters': parameters}
        with open(path, 'wb') as file:
            pickle.dump(save_data, file)

    @classmethod
    def load(cls, path):
        with open(path, 'rb') as file:
            params = pickle.load(file)
        estimator = cls(**params['init_param'])
        estimator.load_state_dict(params['parameters'])
        return estimator


def demo():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = Tokenzier()
    batch_inputs, batch_length = tokenizer.encode(['梦想有多大,舞台就有多大![鼓掌]',
                                                   '[花心][鼓掌]//@小懒猫Melody2011: [春暖花开]'])
    estimator = SentimentAnalysis(vocab_size=tokenizer.get_vocab_size(),
                                  num_labels=2,
                                  padding_idx=tokenizer.pad).to(device)

    # Sort the batch inputs by length in descending order (note: labels must be sorted accordingly)
    sorted_index = torch.argsort(batch_length, descending=True)
    batch_inputs = batch_inputs[sorted_index].to(device)
    batch_length = batch_length[sorted_index]

    outputs = estimator(batch_inputs, batch_length)
    print(outputs)


if __name__ == '__main__':
    demo()
```
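Once trained, the model can be persisted and restored with the `save` and `load` methods above, roughly as follows; the file name and `vocab_size` value here are placeholders, not values from the tutorial:

```python
# Placeholder values for illustration only
model = SentimentAnalysis(vocab_size=5000, num_labels=2, padding_idx=0)
model.save('sentiment_analysis.bin')

restored = SentimentAnalysis.load('sentiment_analysis.bin')
restored.eval()   # switch to inference mode before predicting
```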