This code implements an RNN-based sentiment analysis model. Its main parts are:
- Model definition: an Embedding layer, an RNN layer, and a Linear layer for text classification.
- Input handling: variable-length sequences are packed with pack_padded_sequence so that PAD tokens do not affect the computation.
- Saving and loading: the model can be serialized and restored with pickle.
- Demo: example texts are encoded, sorted by length, fed to the model, and the classification outputs are printed.
Note: in the code below, nn.RNN can be swapped for nn.GRU with no other changes; nn.LSTM additionally returns a cell state, so the forward pass needs a small adjustment (see the sketch after the code listing). Both gated variants are usually easier to train than a plain RNN.
Create a file named estimator.py and add the following code:
import torch.nn as nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
from tokenizer import Tokenzier
import pickle
class SentimentAnalysis(nn.Module):
def __init__(self, vocab_size=0, num_labels=2, padding_idx=0):
super(SentimentAnalysis, self).__init__()
self.vocab_size = vocab_size
self.num_labels = num_labels
self.padding_idx = padding_idx
self.ebd = nn.Embedding(num_embeddings=vocab_size, embedding_dim=128, padding_idx=padding_idx)
self.rnn = nn.RNN(input_size=128, hidden_size=256, batch_first=True)
self.out = nn.Linear(in_features=256, out_features=num_labels)
    # Defining forward (rather than overriding __call__) lets nn.Module handle call dispatch and hooks
    def forward(self, input_ids, batch_length):
inputs = self.ebd(input_ids)
        # Pack the padded batch into a PackedSequence so that PAD positions do not take part in the RNN computation
inputs = pack_padded_sequence(inputs, lengths=batch_length, batch_first=True, enforce_sorted=True)
output, hn = self.rnn(inputs)
        # Convert the PackedSequence back into a padded batch (not needed here, since only the final hidden state is used)
# output, lens = pad_packed_sequence(output, batch_first=True, padding_value=self.ebd.padding_idx)
# logging.info(f'output: {output.shape}, lens: {lens}, hn: {hn.shape}')
        # hn has shape (num_layers, batch, hidden_size); use the last layer's hidden state as the sequence representation
        logits = self.out(hn[-1])
        return logits
def save(self, path):
init_param = {'vocab_size': self.vocab_size, 'num_labels': self.num_labels, 'padding_idx': self.padding_idx}
parameters = self.state_dict()
save_data = {'init_param': init_param, 'parameters': parameters}
        with open(path, 'wb') as f:
            pickle.dump(save_data, f)
@classmethod
def load(cls, path):
        with open(path, 'rb') as f:
            params = pickle.load(f)
        estimator = cls(**params['init_param'])
estimator.load_state_dict(params['parameters'])
return estimator
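# Usage sketch (the 'model.pkl' path below is only an illustrative example, not part of the original code):
#   estimator.save('model.pkl')                     # serialize the init params and weights with pickle
#   restored = SentimentAnalysis.load('model.pkl')  # rebuild the model and restore its weights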
def demo():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = Tokenzier()
batch_inputs, batch_length = tokenizer.encode(['梦想有多大,舞台就有多大![鼓掌]', '[花心][鼓掌]//@小懒猫Melody2011: [春暖花开]'])
estimator = SentimentAnalysis(vocab_size=tokenizer.get_vocab_size(), num_labels=2, padding_idx=tokenizer.pad).to(device)
    # Sort the batch by length in descending order, as required by enforce_sorted=True (labels, if any, must be reordered the same way)
sorted_index = torch.argsort(batch_length, descending=True)
batch_inputs = batch_inputs[sorted_index].to(device)
batch_length = batch_length[sorted_index]
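    # Note (beyond the original code): pack_padded_sequence also accepts enforce_sorted=False,
    # which lets PyTorch sort the batch internally, so this manual sorting step could then be dropped.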
outputs = estimator(batch_inputs, batch_length)
print(outputs)
if __name__ == '__main__':
demo()
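As noted above, nn.GRU is a drop-in replacement for nn.RNN in this model. nn.LSTM, however, also returns a cell state, so the final hidden state comes back inside a tuple. The following is a minimal sketch of the parts that change; it assumes the same layer sizes as the code above and is an illustration rather than part of the original file:

# in __init__: swap the recurrent layer
self.rnn = nn.LSTM(input_size=128, hidden_size=256, batch_first=True)

# in forward: nn.LSTM returns (output, (hn, cn)) instead of (output, hn)
output, (hn, cn) = self.rnn(inputs)
logits = self.out(hn[-1])  # hn[-1] is the last layer's hidden state, shape (batch, hidden_size)
return logits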
