数据集是中文的酒店评论,共有 50216 + 12555 条评论,前者是训练集,后者是验证集。clean_data 函数是对评论做的一些简单的处理。train_data 的数据对象为:
DatasetDict({ train: Dataset({ features: ['review', '__index_level_0__', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'], num_rows: 50216 }) valid: Dataset({ features: ['review', '__index_level_0__', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'], num_rows: 12555 }) })
我们最终会将数据序列化到 data/senti-dataset 文件中,数据文件结构如下:
senti-dataset/ ├── dataset_dict.json ├── test │ ├── dataset.arrow │ ├── dataset_info.json │ └── state.json ├── train │ ├── dataset.arrow │ ├── dataset_info.json │ └── state.json └── valid ├── dataset.arrow ├── dataset_info.json └── state.json
完整代码如下:
from datasets import load_dataset from datasets import Dataset import pandas as pd import zhconv import re from transformers import BertTokenizer def clean_data(inputs: str): # 繁体转简体 inputs = zhconv.convert(inputs, 'zh-hans') # 大写转小写 inputs = inputs.lower() # 删除 "免费注册 网站导航..." start = inputs.find('免费注册') if start != -1: inputs = inputs[:start] # 去除除了中文、数字、字母、逗号、句号、问号 # inputs = inputs.replace(',', ',') # inputs = inputs.replace('!', '!') # inputs = inputs.replace('?', '?') # inputs = inputs.replace('.', '。') inputs = re.sub(r'[^\u4e00-\u9fa50-9a-z]', ' ', inputs) # 替换连续重复 inputs = re.sub(r'(.)\1+', r'\1', inputs) # 去除多余空格 inputs = ' '.join(inputs.split()) return inputs def preprocess(): tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') train_data = pd.read_csv('data/online_shopping_10_cats.csv') train_data = train_data[['label', 'review']] train_data = train_data.dropna() # DataFrame 转换为 Dataset 对象 train_data = Dataset.from_pandas(train_data) # 处理数据 # 注意: Trainer 要求输入的数据带有 labels 标签 train_data = train_data.map(lambda x: {'labels': x['label'], 'review': clean_data(x['review'])}) # 过滤空数据 train_data = train_data.filter(lambda x: len(x['review']) != 0) # map 函数会将返回的字典并到原来的字典中 train_data = train_data.map( lambda x: tokenizer(x['review'], truncation=True, padding='max_length', max_length=256), batched=True) # 删除某些列 train_data = train_data.remove_columns(['label']) # 分割数据集 train_data = train_data.train_test_split(test_size=0.2) train_data['valid'] = train_data.pop('test') print(train_data) # 存储数据 train_data.save_to_disk('data/senti-dataset') if __name__ == '__main__': preprocess()