Fine-Tuning BERT for Hotel Review Classification – Model Evaluation

The training run in the previous part produced a number of checkpoints:

checkpoint-10000  checkpoint-18000  checkpoint-24000  checkpoint-4000
checkpoint-12000  checkpoint-2000   checkpoint-26000  checkpoint-6000
checkpoint-14000  checkpoint-20000  checkpoint-28000  checkpoint-8000
checkpoint-16000  checkpoint-22000  checkpoint-30000  checkpoint-final

Next, we evaluate each checkpoint on the test set, measuring accuracy, precision, recall, and F1-score.
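Before running the evaluation, it helps to recall what these four metrics actually count. Below is a pure-Python sketch on a small made-up label list, purely for illustration; in the real code, sklearn's accuracy_score, precision_score, recall_score, and f1_score do this work for us:

```python
# Made-up labels for a binary task (0 = negative review, 1 = positive review)
y_true = [1, 1, 1, 1, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]

def binary_report(y_true, y_pred, pos_label):
    # Count true positives, false positives and false negatives for one class
    tp = sum(t == pos_label and p == pos_label for t, p in zip(y_true, y_pred))
    fp = sum(t != pos_label and p == pos_label for t, p in zip(y_true, y_pred))
    fn = sum(t == pos_label and p != pos_label for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Accuracy is just the fraction of matching labels
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print('accuracy:', accuracy)
print('class 1 (precision, recall, f1):', binary_report(y_true, y_pred, 1))
print('class 0 (precision, recall, f1):', binary_report(y_true, y_pred, 0))
```

Computing precision and recall per class (via pos_label) is exactly why the report below carries separate `_0` and `_1` rows for each metric.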

1. Model Evaluation

On my machine, one evaluation pass over all 12,555 test samples takes a little over 2 minutes; with 16 checkpoints in total, evaluating every model takes roughly 35 minutes.

The complete evaluation code is as follows:

import glob

import torch
from datasets import load_from_disk
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import BertForSequenceClassification


def evaluate(model_path, valid_data):

    # Load the checkpoint under evaluation
    test_model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)
    test_model.to('cuda')
    test_model.eval()

    all_y_true = []
    all_y_pred = []

    with torch.no_grad():

        # Iterate over the test set in batches of 32; thanks to
        # set_format('pytorch', device='cuda'), each slice is already
        # a dict of CUDA tensors ready to feed into the model
        for start in range(0, valid_data.num_rows, 32):

            inputs = valid_data[start: start + 32]
            outputs = test_model(**inputs)
            y_true = inputs['labels']
            y_pred = torch.argmax(outputs.logits, dim=-1)

            all_y_true.extend(y_true.cpu().numpy().tolist())
            all_y_pred.extend(y_pred.cpu().numpy().tolist())

    # Compute the evaluation metrics for both classes
    accuracy = accuracy_score(all_y_true, all_y_pred)
    precis_0 = precision_score(all_y_true, all_y_pred, pos_label=0)
    precis_1 = precision_score(all_y_true, all_y_pred, pos_label=1)
    recall_0 = recall_score(all_y_true, all_y_pred, pos_label=0)
    recall_1 = recall_score(all_y_true, all_y_pred, pos_label=1)
    fscore_0 = f1_score(all_y_true, all_y_pred, pos_label=0)
    fscore_1 = f1_score(all_y_true, all_y_pred, pos_label=1)

    print('------ Test Set Report ------')
    print('accuracy: %.5f' % accuracy)
    print('--------------------')
    print('precis_0: %.5f' % precis_0)
    print('precis_1: %.5f' % precis_1)
    print('--------------------')
    print('recall_0: %.5f' % recall_0)
    print('recall_1: %.5f' % recall_1)
    print('--------------------')
    print('fscore_0: %.5f' % fscore_0)
    print('fscore_1: %.5f' % fscore_1)
    print('--------------------')


def evaluate_models():

    data = load_from_disk('data/senti-dataset')
    data.set_format('pytorch',
                    columns=['labels',
                             'input_ids',
                             'token_type_ids',
                             'attention_mask'],
                    device='cuda')
    valid_data = data['valid'].remove_columns(['review', '__index_level_0__'])
    print('Total samples:', valid_data.num_rows)

    # 1. Collect the checkpoint paths
    model_path_list = glob.glob('model/*')
    for model_path in model_path_list:
        model_name = model_path.replace('model/', '')
        print('-' * 20 + model_name + '-' * 20)
        evaluate(model_path, valid_data)
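One caveat: glob.glob returns paths in directory order, so the checkpoints are not necessarily visited in training order (lexicographically, checkpoint-8000 even sorts after checkpoint-30000). If you want them walked in step order, sort by the numeric suffix. A small sketch with a hypothetical path list standing in for the glob result, placing checkpoint-final last:

```python
# Hypothetical list standing in for glob.glob('model/*')
model_path_list = ['model/checkpoint-10000', 'model/checkpoint-2000',
                   'model/checkpoint-final', 'model/checkpoint-8000']

def step_key(path):
    # Numeric checkpoints sort by training step; checkpoint-final goes last
    suffix = path.rsplit('-', 1)[-1]
    return (1, 0) if suffix == 'final' else (0, int(suffix))

ordered = sorted(model_path_list, key=step_key)
print(ordered)
```

Passing `key=step_key` into `sorted(glob.glob('model/*'), ...)` in evaluate_models would make the reports come out in training order.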

The evaluation output is as follows (excerpt; not every checkpoint's report is shown):

accuracy: 0.86874
--------------------
precis_0: 0.86322
precis_1: 0.87424
--------------------
recall_0: 0.87268
recall_1: 0.86488
--------------------
fscore_0: 0.86793
fscore_1: 0.86954
--------------------

--------------------checkpoint-8000--------------------
accuracy: 0.87360
--------------------
precis_0: 0.86443
precis_1: 0.88294
--------------------
recall_0: 0.88268
recall_1: 0.86472
--------------------
fscore_0: 0.87346
fscore_1: 0.87374
--------------------

--------------------checkpoint-10000--------------------
accuracy: 0.87686
--------------------
precis_0: 0.86737
precis_1: 0.88655
--------------------
recall_0: 0.88638
recall_1: 0.86756
--------------------
fscore_0: 0.87677
fscore_1: 0.87695
--------------------

--------------------checkpoint-12000--------------------
accuracy: 0.87814
--------------------
precis_0: 0.87215
precis_1: 0.88412
--------------------
recall_0: 0.88284
recall_1: 0.87354
--------------------
fscore_0: 0.87746
fscore_1: 0.87880
--------------------

--------------------checkpoint-14000--------------------
accuracy: 0.88037
--------------------
precis_0: 0.87308
precis_1: 0.88772
--------------------
recall_0: 0.88687
recall_1: 0.87402
--------------------
fscore_0: 0.87992
fscore_1: 0.88081
--------------------

--------------------checkpoint-16000--------------------
accuracy: 0.88220
--------------------
precis_0: 0.87224
precis_1: 0.89238
--------------------
recall_0: 0.89234
recall_1: 0.87228
--------------------
fscore_0: 0.88218
fscore_1: 0.88222
--------------------

--------------------checkpoint-18000--------------------
accuracy: 0.88284
--------------------
precis_0: 0.87631
precis_1: 0.88939
--------------------
recall_0: 0.88832
recall_1: 0.87748
--------------------
fscore_0: 0.88227
fscore_1: 0.88339
--------------------

--------------------checkpoint-20000--------------------
accuracy: 0.88371
--------------------
precis_0: 0.87677
precis_1: 0.89070
--------------------
recall_0: 0.88977
recall_1: 0.87780
--------------------
fscore_0: 0.88322
fscore_1: 0.88420
--------------------

--------------------checkpoint-22000--------------------
accuracy: 0.88491
--------------------
precis_0: 0.87363
precis_1: 0.89652
--------------------
recall_0: 0.89686
recall_1: 0.87323
--------------------
fscore_0: 0.88509
fscore_1: 0.88472
--------------------

--------------------checkpoint-24000--------------------
accuracy: 0.88586
--------------------
precis_0: 0.87622
precis_1: 0.89570
--------------------
recall_0: 0.89557
recall_1: 0.87638
--------------------
fscore_0: 0.88579
fscore_1: 0.88593
--------------------

--------------------checkpoint-26000--------------------
accuracy: 0.88515
--------------------
precis_0: 0.87322
precis_1: 0.89747
--------------------
recall_0: 0.89799
recall_1: 0.87260
--------------------
fscore_0: 0.88543
fscore_1: 0.88486
--------------------

--------------------checkpoint-28000--------------------
accuracy: 0.88530
--------------------
precis_0: 0.87163
precis_1: 0.89958
--------------------
recall_0: 0.90056
recall_1: 0.87039
--------------------
fscore_0: 0.88586
fscore_1: 0.88474
--------------------

--------------------checkpoint-30000--------------------
accuracy: 0.88546
--------------------
precis_0: 0.87318
precis_1: 0.89818
--------------------
recall_0: 0.89879
recall_1: 0.87244
--------------------
fscore_0: 0.88580
fscore_1: 0.88513
--------------------

--------------------checkpoint-final--------------------
accuracy: 0.88530
--------------------
precis_0: 0.87349
precis_1: 0.89751
--------------------
recall_0: 0.89799
recall_1: 0.87291
--------------------
fscore_0: 0.88557
fscore_1: 0.88504
--------------------
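Tabulating the accuracies above makes the winner easy to pick out programmatically. A quick sketch, with the values copied straight from the report:

```python
# Test-set accuracy of each checkpoint, copied from the report above
accuracy_by_ckpt = {
    'checkpoint-8000': 0.87360,  'checkpoint-10000': 0.87686,
    'checkpoint-12000': 0.87814, 'checkpoint-14000': 0.88037,
    'checkpoint-16000': 0.88220, 'checkpoint-18000': 0.88284,
    'checkpoint-20000': 0.88371, 'checkpoint-22000': 0.88491,
    'checkpoint-24000': 0.88586, 'checkpoint-26000': 0.88515,
    'checkpoint-28000': 0.88530, 'checkpoint-30000': 0.88546,
    'checkpoint-final': 0.88530,
}

# Checkpoint with the highest test-set accuracy
best = max(accuracy_by_ckpt, key=accuracy_by_ckpt.get)
print(best, accuracy_by_ckpt[best])  # checkpoint-24000 0.88586
```

Accuracy plateaus after roughly step 22000, and checkpoint-24000 edges out the later checkpoints.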

2. Model Prediction

After the evaluation is complete, we choose checkpoint-24000, which performed relatively best on the test set, and feed it a few arbitrary reviews to inspect the predictions.

import torch
from transformers import BertForSequenceClassification, BertTokenizer


def predict(text):

    # Load the fine-tuned task model
    test_model = BertForSequenceClassification.from_pretrained('model/checkpoint-24000', num_labels=2)
    # test_model.to('cuda')
    test_model.eval()
    # Tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    # Clean the raw text (clean_data is the helper defined in the preprocessing part of this series)
    inputs = clean_data(text)
    # Encode the text
    inputs = tokenizer(inputs, truncation=True, max_length=256, padding='max_length', return_tensors='pt')
    # Forward pass
    with torch.no_grad():
        outputs = test_model(**inputs)
    # Take the class with the larger logit
    y_pred = torch.argmax(outputs.logits, dim=-1)
    print('Input:', text)
    print('Prediction:', 'positive' if y_pred.item() == 1 else 'negative')


if __name__ == '__main__':
    predict('酒店设施还可以,总体我觉得不错')
    predict('什么垃圾地方,连个矿泉水都没有,下次再也不来了')
    predict('一进酒店,有简单的装饰,中式的简单装修,进入标间,两张床,白色的床单,必备的台灯,还有塑料拖鞋。')
    predict('非常安静,非常赞地理位置:离地铁站很近很方便,前台小姐姐,服务态度很好,很贴心')
    predict('我以前去过其他的酒店,服务特别好,但是在这家酒店就差很多')

Prediction results:

Input: 酒店设施还可以,总体我觉得不错
Prediction: positive
Input: 什么垃圾地方,连个矿泉水都没有,下次再也不来了
Prediction: negative
Input: 一进酒店,有简单的装饰,中式的简单装修,进入标间,两张床,白色的床单,必备的台灯,还有塑料拖鞋。
Prediction: negative
Input: 非常安静,非常赞地理位置:离地铁站很近很方便,前台小姐姐,服务态度很好,很贴心
Prediction: positive
Input: 我以前去过其他的酒店,服务特别好,但是在这家酒店就差很多
Prediction: negative
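Note that argmax only yields a hard label. If you also want a confidence score for each prediction, run the logits through softmax first. A pure-Python sketch with made-up logits, just to show the arithmetic; in the real code, torch.softmax(outputs.logits, dim=-1) computes the same probabilities:

```python
import math

# Made-up logits for one review: [negative_score, positive_score]
logits = [-1.2, 2.3]

# Numerically stable softmax: subtract the max before exponentiating
m = max(logits)
exps = [math.exp(x - m) for x in logits]
probs = [e / sum(exps) for e in exps]

# The predicted label is the index of the largest probability
label = probs.index(max(probs))
print('Prediction:', 'positive' if label == 1 else 'negative')
print('Confidence: %.4f' % probs[label])
```

A probability close to 0.5 flags reviews the model is unsure about, such as the neutral descriptive review above that was classified as negative.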