《文本情感分析》(五)模型搭建

该代码实现了一个基于 RNN 的情感分析模型。主要功能包括:

  1. 定义模型:包含 EmbeddingRNNLinear 层,用于文本分类。
  2. 处理输入:使用 pack_padded_sequence 处理变长序列,避免 PAD 影响计算。
  3. 模型存储与加载:支持 pickle 序列化保存和恢复。
  4. 测试示例:对文本进行编码、排序并输入模型,输出分类结果。

注意:下面代码中 nn.RNN 可以直接替换为 nn.GRU、nn.LSTM,更容易训练。

创建 estimator.py 文件并添加如下代码:

import torch.nn as nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from tokenizer import Tokenzier
import pickle


class SentimentAnalysis(nn.Module):

    def __init__(self, vocab_size=0, num_labels=2, padding_idx=0):
        super(SentimentAnalysis, self).__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.padding_idx = padding_idx
        self.ebd = nn.Embedding(num_embeddings=vocab_size, embedding_dim=128, padding_idx=padding_idx)
        self.rnn = nn.RNN(input_size=128, hidden_size=256, batch_first=True)
        self.out = nn.Linear(in_features=256, out_features=num_labels)

    def __call__(self, input_ids, batch_length):
        inputs = self.ebd(input_ids)
        # 将带有 pad 的批次输入转换为 PackedSequence 格式,避免 pad 参与 rnn 计算
        inputs = pack_padded_sequence(inputs, lengths=batch_length, batch_first=True, enforce_sorted=True)
        output, hn  = self.rnn(inputs)
        # 将 PackedSequence 转换为带有 pad 的批次数据
        # output, lens = pad_packed_sequence(output, batch_first=True, padding_value=self.ebd.padding_idx)
        # logging.info(f'output: {output.shape}, lens: {lens}, hn: {hn.shape}')

        inputs = self.out(hn.squeeze())

        return inputs

    def save(self, path):
        init_param = {'vocab_size': self.vocab_size, 'num_labels': self.num_labels, 'padding_idx': self.padding_idx}
        parameters = self.state_dict()
        save_data = {'init_param': init_param, 'parameters': parameters}
        pickle.dump(save_data, open(path, 'wb'))

    @classmethod
    def load(cls, path):
        params = pickle.load(open(path, 'rb'))
        estimator = SentimentAnalysis(**(params['init_param']))
        estimator.load_state_dict(params['parameters'])
        return estimator


def demo():

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = Tokenzier()
    batch_inputs, batch_length = tokenizer.encode(['梦想有多大,舞台就有多大![鼓掌]', '[花心][鼓掌]//@小懒猫Melody2011: [春暖花开]'])
    estimator = SentimentAnalysis(vocab_size=tokenizer.get_vocab_size(), num_labels=2, padding_idx=tokenizer.pad).to(device)

    # 对批次输入根据长度降序排列(注意: 标签也需要相应排序)
    sorted_index = torch.argsort(batch_length, descending=True)
    batch_inputs = batch_inputs[sorted_index].to(device)
    batch_length = batch_length[sorted_index]

    outputs = estimator(batch_inputs, batch_length)
    print(outputs)


if __name__ == '__main__':
    demo()
未经允许不得转载:一亩三分地 » 《文本情感分析》(五)模型搭建
评论 (0)

8 + 2 =