Fine-tuning an OpenAI model for a classification task involves roughly the following steps:
- Prepare the data: first preprocess the text yourself, then run OpenAI's data-preparation tool over it for a second pass
- Fine-tune the model: upload the prepared data and specify a pretrained base model to fine-tune
- Use the model: run sentiment classification with the fine-tuned model
Docs: https://platform.openai.com/docs/api-reference/fine-tunes
1. Prepare the Data
import pandas as pd


def prepare_data():
    # Read the raw reviews: the 'label' column is 0 (negative) or 1 (positive).
    data = pd.read_csv('data/comments.csv')
    class_0_data = data[data['label'] == 0]
    class_1_data = data[data['label'] == 1]
    # Take 10 examples from each class and reset the index to 0..19.
    new_data = pd.concat([class_0_data[:10], class_1_data[:10]])
    new_data.index = pd.Series(list(range(20)))
    # Replace the numeric labels with the text labels 消极 (negative) / 积极 (positive).
    new_data['label'] = pd.Series(['消极'] * 10 + ['积极'] * 10)
    # Rename the columns to the names the OpenAI tooling expects.
    new_data.columns = ['completion', 'prompt']
    new_data.to_json('data/fine_tuning.json', orient='records', lines=True)


if __name__ == '__main__':
    prepare_data()
The raw data looks like this:
      label                                             review
0         1  距离川沙公路较近,但是公交指示不对,如果是"蔡陆线"的话,会非常麻烦.建议用别的路线.房间较...
1         1  商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
2         1  早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。
3         1  宾馆在小街道上,不大好找,但还好北京热心同胞很多~宾馆设施跟介绍的差不多,房间很小,确实挺小...
4         1  CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风
...     ...                                                ...
7761      0  尼斯酒店的几大特点:噪音大、环境差、配置低、服务效率低。如:1、隔壁歌厅的声音闹至午夜3点许...
7762      0  盐城来了很多次,第一次住盐阜宾馆,我的确很失望整个墙壁黑咕隆咚的,好像被烟熏过一样家具非常的...
7763      0  看照片觉得还挺不错的,又是4星级的,但入住以后除了后悔没有别的,房间挺大但空空的,早餐是有但...
7764      0  我们去盐城的时候那里的最低气温只有4度,晚上冷得要死,居然还不开空调,投诉到酒店客房部,得到...
7765      0  说实在的我很失望,之前看了其他人的点评后觉得还可以才去的,结果让我们大跌眼镜。我想这家酒店以...

[7766 rows x 2 columns]
After processing with the prepare_data function, the data looks like this:
   completion                                             prompt
0          消极  标准间太差房间还不如3星的而且设施非常陈旧.建议酒店把老的标准间从新改善.
1          消极  服务态度极其差,前台接待好象没有受过培训,连基本的礼貌都不懂,竟然同时接待几个客人;大堂副理...
2          消极  地理位置还不错,到哪里都比较方便,但是服务不象是豪生集团管理的,比较差。下午睡了一觉并洗了一...
3          消极  1。我住的是靠马路的标准间。房间内设施简陋,并且的房间玻璃窗户外还有一层幕墙玻璃,而且不能打...
4          消极  我这次是第5次住在长春的雁鸣湖大酒店。昨晚夜里停电。深夜我睡着了。我的钱包被内贼进入我的房间...
5          消极  前台checkin花了20分钟,checkout25分钟,这是服务态度和没有做到位。信用卡刷...
6          消极  有或者很少房!梯部不吸,但是有一些吸者仍然有服!我是不抽的人,成二手的受害者!(中13人口中...
7          消极  酒店服务态度极差,设施很差,建议还是不要到那儿去。
8          消极  我3.6预定好的180的标间,当我到的时候竟然说有会议房间满了,我订的房间没有了,太不讲信誉...
9          消极  房间的环境非常差,而且房间还不隔音,住的不舒服。
10         积极  距离川沙公路较近,但是公交指示不对,如果是"蔡陆线"的话,会非常麻烦.建议用别的路线.房间较...
11         积极  商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!
12         积极  早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。
13         积极  宾馆在小街道上,不大好找,但还好北京热心同胞很多~宾馆设施跟介绍的差不多,房间很小,确实挺小...
14         积极  CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风
15         积极  总的来说,这样的酒店配这样的价格还算可以,希望他赶快装修,给我的客人留些好的印象
16         积极  价格比比较不错的酒店。这次免费升级了,感谢前台服务员。房子还好,地毯是新的,比上次的好些。早...
17         积极  不错,在同等档次酒店中应该是值得推荐的!
18         积极  入住丽晶,感觉很好。因为是新酒店,的确有淡淡的油漆味,房间内较新。房间大小合适,卫生间设备齐...
19         积极  1。酒店比较新,装潢和设施还不错,只是房间有些油漆味。2。早餐还可以,只是品种不是很多。3。...
Next, use OpenAI's data-preparation tool to turn the data above into a format suitable for fine-tuning. cd into the directory where the prepare_data function wrote its output and run the following command:
openai tools fine_tunes.prepare_data -f fine_tuning.json -q
This produces output like the following:
Analyzing...

- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format
- Your file contains 20 prompt-completion pairs. In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples
- Based on your data it seems like you're trying to fine-tune a model for classification
- For classification, we recommend you try one of the faster and cheaper models, such as `ada`
- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training
- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty
- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details

Based on the analysis we will perform the following actions:
- [Necessary] Your format `JSON` will be converted to `JSONL`
- [Recommended] Add a suffix separator ` ->` to all prompts [Y/n]: Y
- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y
- [Recommended] Would you like to split into training and validation set? [Y/n]: Y

Your data will be written to a new JSONL file. Proceed [Y/n]: Y

Wrote modified files to `fine_tuning_prepared_train.jsonl` and `fine_tuning_prepared_valid.jsonl`
Feel free to take a look!

Now use that file when fine-tuning:
> openai api fine_tunes.create -t "fine_tuning_prepared_train.jsonl" -v "fine_tuning_prepared_valid.jsonl" --compute_classification_metrics --classification_positive_class " 消极"

After you've fine-tuned a model, remember that your prompt has to end with the indicator string ` ->` for the model to start generating completions, rather than continuing with the prompt. Make sure to include `stop=["极"]` so that the generated texts ends at the expected place.
Once your model starts training, it'll approximately take 2.81 minutes to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.
A few of the key points in this output:

- "Your file contains 20 prompt-completion pairs. In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples." Our fine-tune uses far too few examples: the log recommends at least a few hundred, and tells us that the more training examples we have, the better the model performs.
- "Based on your data it seems like you're trying to fine-tune a model for classification. For classification, we recommend you try one of the faster and cheaper models, such as `ada`." From the data alone, the tool recognizes that we are doing a classification task, and it recommends the faster and cheaper ada model as the pretrained base model for fine-tuning.
- "Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin." Our training data has no separator at the end of the prompts; this marker tells the model where the prompt ends and the completion should begin. The tool therefore appends a special marker after each prompt. Exactly which marker is used may vary across openai tool versions, so we will simply look at the generated dataset in a moment.
In the end, OpenAI's data-preparation tool performs the following actions:

- [Necessary] Your format `JSON` will be converted to `JSONL`: the JSON file is converted to JSONL format
- [Recommended] Add a suffix separator ` ->` to all prompts [Y/n]: Y: appends ` ->` after each prompt as the separator
- [Recommended] Add a whitespace character to the beginning of the completion [Y/n]: Y: adds a space before each completion, marking the boundary between prompt and completion
- [Recommended] Would you like to split into training and validation set? [Y/n]: Y: splits the data into a training set and a validation set
Finally, two files are generated in the data directory:
- fine_tuning_prepared_train.jsonl: the training set
- fine_tuning_prepared_valid.jsonl: the validation set
Let's take a look at the contents of the validation set:
{"prompt":"有或者很少房!梯部不吸,但是有一些吸者仍然有服!我是不抽的人,成二手的受害者!(中13人口中,民只有3.2.不到1\/4!!!)看到的民,自好?. ->","completion":" 消极"} {"prompt":"酒店服务态度极差,设施很差,建议还是不要到那儿去。 ->","completion":" 消极"} {"prompt":"距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较为简单. ->","completion":" 积极"} {"prompt":"CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风 ->","completion":" 积极"}
As you can see, a " ->" marker has been appended after each prompt: it marks the end of the prompt and tells the model where to start generating the completion. A space has also been added before each completion to separate it from the prompt. Note that when we later classify new text with this model, we must not forget to append the " ->" separator to the new input prompt.
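As a quick sanity check, here is a minimal Python sketch (assuming the prepared files sit in the current directory) that verifies every validation example follows these two conventions:

import json

# Assumption: fine_tuning_prepared_valid.jsonl is in the current directory.
with open('fine_tuning_prepared_valid.jsonl', encoding='utf-8') as f:
    for line in f:
        example = json.loads(line)
        # Every prompt should end with the " ->" separator...
        assert example['prompt'].endswith(' ->')
        # ...and every completion should start with a whitespace character.
        assert example['completion'].startswith(' ')
print('all examples follow the expected format')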
2. Fine-Tune the Model
The output above already mentions the command to use if we want to fine-tune:
openai api fine_tunes.create -t "fine_tuning_prepared_train.jsonl" -v "fine_tuning_prepared_valid.jsonl" --compute_classification_metrics --classification_positive_class " 消极"
- -t specifies the training dataset
- -v specifies the validation dataset
- --compute_classification_metrics computes evaluation metrics on the validation set
- --classification_positive_class specifies the positive-class label in the dataset; here we change the generated " 消极" to " 积极"
The final command looks like this (this approach requires the OPENAI_API_KEY environment variable to be set on the system):
openai api fine_tunes.create -t "fine_tuning_prepared_train.jsonl" -v "fine_tuning_prepared_valid.jsonl" --compute_classification_metrics --classification_positive_class " 积极"
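For reference, the OpenAI CLI reads the key from OPENAI_API_KEY; on Linux/macOS it can be set for the current shell session like this (the key value below is a placeholder):

# Replace the placeholder with your own API key
export OPENAI_API_KEY="sk-..."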
Some other hyperparameters (a sample invocation follows this list):

- model: the name of the base model to fine-tune. You can choose "ada", "babbage", "curie", or "davinci".
- n_epochs: defaults to 4.
- batch_size: defaults to ~0.2% of the number of examples in the training set, capped at 256. Larger batch sizes generally work better on larger datasets.
- learning_rate_multiplier: defaults to 0.05, 0.1, or 0.2, depending on the final batch size.
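If we wanted to override these defaults, the legacy fine-tunes CLI exposes matching flags. The sketch below is an assumption on my part; verify each flag name with `openai api fine_tunes.create --help` before running:

# Sketch only: flag names assumed to mirror the hyperparameter names
openai api fine_tunes.create \
  -t "fine_tuning_prepared_train.jsonl" \
  -v "fine_tuning_prepared_valid.jsonl" \
  -m ada \
  --n_epochs 4 \
  --batch_size 1 \
  --learning_rate_multiplier 0.1 \
  --compute_classification_metrics \
  --classification_positive_class " 积极"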
After executing the actual command from above, the output is as follows:
Upload progress: 100%|███████████████████████████████████████████████████████████████████████████████████| 6.39k/6.39k [00:00<00:00, 7.65Mit/s]
Uploaded file from fine_tuning_prepared_train.jsonl: file-Euf0oLwGX1hx4nbHZu9o8LgC
Upload progress: 100%|███████████████████████████████████████████████████████████████████████████████████████| 636/636 [00:00<00:00, 1.06Mit/s]
Uploaded file from fine_tuning_prepared_valid.jsonl: file-gu3kbiOr7yCKGgN6YLsYg7fY
Created fine-tune: ft-7xhbzqbkIxUiCRBvsWEVjQyE
Streaming events until fine-tuning is complete...

(Ctrl-C will interrupt the stream, but not cancel the fine-tune)
[2023-04-09 17:55:48] Created fine-tune: ft-7xhbzqbkIxUiCRBvsWEVjQyE
[2023-04-09 17:55:59] Fine-tune costs $0.05
[2023-04-09 17:55:59] Fine-tune enqueued. Queue number: 0
[2023-04-09 17:56:01] Fine-tune started
Fine-tuning a model costs money; the output above shows that this training run cost $0.05. Use the following command to view our fine-tuned models:
openai api fine_tunes.list
The output looks like this:
{ "data": [ { "created_at": 1681034148, "fine_tuned_model": "curie:ft-personal-2023-04-09-09-57-44", "hyperparams": { "batch_size": 1, "classification_positive_class": " \u79ef\u6781", "compute_classification_metrics": true, "learning_rate_multiplier": 0.1, "n_epochs": 4, "prompt_loss_weight": 0.01 }, "id": "ft-7xhbzqbkIxUiCRBvsWEVjQyE", "model": "curie", "object": "fine-tune", "organization_id": "org-CYfr2zckOKWqMlBgQHaL8rzp", "result_files": [ { "bytes": 4214, "created_at": 1681034264, "filename": "compiled_results.csv", "id": "file-9djdWGJeysmtpT7uYRpzuvkm", "object": "file", "purpose": "fine-tune-results", "status": "processed", "status_details": null } ], "status": "succeeded", "training_files": [ { "bytes": 6393, "created_at": 1681034146, "filename": "fine_tuning_prepared_train.jsonl", "id": "file-Euf0oLwGX1hx4nbHZu9o8LgC", "object": "file", "purpose": "fine-tune", "status": "processed", "status_details": null } ], "updated_at": 1681034265, "validation_files": [ { "bytes": 636, "created_at": 1681034148, "filename": "fine_tuning_prepared_valid.jsonl", "id": "file-gu3kbiOr7yCKGgN6YLsYg7fY", "object": "file", "purpose": "fine-tune", "status": "processed", "status_details": null } ] } ], "object": "list" }
From this output we can read off the following information:

- "fine_tuned_model": "curie:ft-personal-2023-04-09-09-57-44": the ID of the fine-tuned model we will call
- "hyperparams": the hyperparameters used during fine-tuning
- "id": "ft-7xhbzqbkIxUiCRBvsWEVjQyE": the fine-tune job ID, used to fetch information about the job
- "model": "curie": the base model we fine-tuned is curie, the default; we could also pass -m during fine-tuning to choose ada
- "status": "succeeded": the status of the fine-tune; here it shows the job finished successfully (see the status-check sketch after this list)
- ...and other fields
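While a job is still queued or training (or after interrupting the event stream with Ctrl-C), the legacy CLI can also query or re-follow a job by its ID. I believe the subcommands below exist in this CLI version, but treat them as assumptions and confirm with `openai api --help`:

# Assumption: these subcommands are available in this version of the CLI
openai api fine_tunes.get -i ft-7xhbzqbkIxUiCRBvsWEVjQyE      # one-off status query
openai api fine_tunes.follow -i ft-7xhbzqbkIxUiCRBvsWEVjQyE   # resume streaming events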
Use the following command to download the training results to your machine:
# -i specifies the fine-tune job ID
openai api fine_tunes.results -i ft-7xhbzqbkIxUiCRBvsWEVjQyE > result.csv
This generates a result.csv file locally. Let's look at what it contains; the text below can be saved to a CSV file for easier viewing:
step,elapsed_tokens,elapsed_examples,training_loss,training_sequence_accuracy,training_token_accuracy,validation_loss,validation_sequence_accuracy,validation_token_accuracy,classification/accuracy,classification/precision,classification/recall,classification/auroc,classification/auprc,classification/f1.0
1,137,1,0.11297712909037737,0.0,0.5,0.1965679380938286,0.0,0.5,,,,,,
2,1138,2,0.026734490502419668,0.0,0.5,,,,,,,,,
3,1547,3,0.041407643495439544,0.0,0.6666666666666666,,,,,,,,,
4,1604,4,0.20218953214144544,0.0,0.3333333333333333,,,,,,,,,
5,1701,5,0.11237559589742213,0.0,0.5,,,,,,,,,
6,1846,6,0.07346868902813872,0.0,0.5,,,,,,,,,
7,1975,7,0.07653379087575103,0.0,0.8333333333333334,,,,,,,,,
8,2992,8,0.021339937947021315,0.0,0.8333333333333334,,,,,,,,,
9,3241,9,0.03946435204677875,0.0,0.8333333333333334,0.11924629365356243,0.0,0.6666666666666666,,,,,,
10,3330,10,0.07196011489376213,0.0,0.6666666666666666,,,,,,,,,
11,3395,11,0.07033659532143474,0.0,0.8333333333333334,,,,,,,,,
12,3612,12,0.03595350411543043,0.0,0.8333333333333334,,,,,,,,,
13,3741,13,0.02119638757410091,1.0,1.0,,,,,,,,,
14,3830,14,0.05899414582442825,0.0,0.8333333333333334,,,,,,,,,
15,3887,15,0.05293363639834362,0.0,0.8333333333333334,,,,,,,,,
16,4272,16,0.01709485127953574,1.0,1.0,,,,,,,,,
17,5289,17,0.015280593367766031,1.0,1.0,0.03453893841861048,1.0,1.0,0.5,0.0,0.0,0.25,0.3333333333333333,0.0
18,6290,18,0.01659894965089464,0.0,0.8333333333333334,,,,,,,,,
19,6387,19,0.08227947945942217,0.0,0.8333333333333334,,,,,,,,,
20,6452,20,0.0673786170370725,0.0,0.8333333333333334,,,,,,,,,
21,6589,21,0.01852450909944652,1.0,1.0,,,,,,,,,
22,6838,22,0.01672595546186104,1.0,1.0,,,,,,,,,
23,6967,23,0.01842574403242967,1.0,1.0,,,,,,,,,
24,7352,24,0.018070253951559655,1.0,1.0,,,,,,,,,
25,7497,25,0.018238978561170335,0.0,0.8333333333333334,0.026812056690901857,0.0,0.8333333333333334,,,,,,
26,7554,26,0.018804501375033807,1.0,1.0,,,,,,,,,
27,7611,27,0.018643728179655778,1.0,1.0,,,,,,,,,
28,8020,28,0.021471951293555423,0.0,0.8333333333333334,,,,,,,,,
29,8149,29,0.01589545032287536,1.0,1.0,,,,,,,,,
30,8238,30,0.02560983169629191,1.0,1.0,,,,,,,,,
31,8327,31,0.016086113855492766,1.0,1.0,,,,,,,,,
32,8544,32,0.025426828228896272,0.0,0.8333333333333334,,,,,,,,,
33,8641,33,0.016373053608218033,1.0,1.0,0.02478175366484695,1.0,1.0,,,,,,
34,8698,34,0.012001360776212056,1.0,1.0,,,,,,,,,
35,8763,35,0.018112453049790214,1.0,1.0,,,,,,,,,
36,8980,36,0.02059483739298285,0.0,0.8333333333333334,,,,,,,,,
37,9109,37,0.020320957661593295,1.0,1.0,,,,,,,,,
38,10110,38,0.01485486638983113,1.0,1.0,,,,0.75,0.6666666666666666,1.0,0.75,0.7916666666666666,0.8
39,10167,39,0.016590197144024585,1.0,1.0,,,,,,,,,
40,10256,40,0.019145875603147972,1.0,1.0,,,,,,,,,
41,10393,41,0.01974262404774302,1.0,1.0,0.03868247642918245,1.0,1.0,,,,,,
42,10482,42,0.017745215463898456,1.0,1.0,,,,,,,,,
43,10867,43,0.01643800483749426,1.0,1.0,,,,,,,,,
44,11276,44,0.016559487510540653,1.0,1.0,,,,,,,,,
45,12293,45,0.014681880749237009,1.0,1.0,,,,,,,,,
46,12542,46,0.015384798510621219,1.0,1.0,,,,,,,,,
47,12687,47,0.013482770004832167,1.0,1.0,,,,,,,,,
48,12816,48,0.016251353808475387,1.0,1.0,,,,,,,,,
49,12945,49,0.015636948876930602,1.0,1.0,0.011417988476832727,1.0,1.0,,,,,,
50,13194,50,0.015154518923057117,1.0,1.0,,,,,,,,,
51,13603,51,0.01619545711472625,1.0,1.0,,,,,,,,,
52,13668,52,0.01764821404665394,1.0,1.0,,,,,,,,,
53,13757,53,0.017546420559014668,1.0,1.0,,,,,,,,,
54,14142,54,0.01629021499741482,1.0,1.0,,,,,,,,,
55,15159,55,0.014517653008841998,1.0,1.0,,,,0.5,0.0,0.0,0.75,0.7916666666666666,0.0
56,15376,56,0.015640235622392317,1.0,1.0,,,,,,,,,
57,15473,57,0.017093179406629962,1.0,1.0,0.037967130073333646,0.0,0.8333333333333334,,,,,,
58,15602,58,0.015698925190150044,1.0,1.0,,,,,,,,,
59,16603,59,0.014752162144404994,1.0,1.0,,,,,,,,,
60,16660,60,0.012103026345412493,1.0,1.0,,,,,,,,,
61,16805,61,0.013015271683916614,1.0,1.0,,,,,,,,,
62,16894,62,0.01824016018865998,1.0,1.0,,,,,,,,,
63,17031,63,0.01693082119428943,1.0,1.0,,,,,,,,,
64,17088,64,0.014343387357669275,1.0,1.0,,,,,,,,,
65,17145,65,0.014343387357669275,1.0,1.0,0.037712673740637276,0.0,0.8333333333333334,,,,,,
66,17362,66,0.015630679206655895,1.0,1.0,,,,0.5,0.0,0.0,0.75,0.7916666666666666,0.0
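Most cells are empty because validation and classification metrics are only computed on evaluation steps. A minimal pandas sketch (assuming result.csv is in the current working directory) filters down to just the rows that carry classification metrics:

import pandas as pd

# Assumption: result.csv was downloaded with the fine_tunes.results command above.
results = pd.read_csv('result.csv')
# Keep only the steps where classification metrics were actually computed.
cls = results[results['classification/accuracy'].notnull()]
print(cls[['step', 'classification/accuracy', 'classification/precision',
           'classification/recall', 'classification/f1.0']])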
3. Use the Model
Note: when sending in new text, don't forget to append " ->" after the prompt. To call the fine-tuned model, we can use the openai Python library directly, or send the request ourselves with the requests module. Example code:
import requests
import json
import openai


def use_fine_tune_model_01():
    # Call the fine-tuned model by sending a raw HTTP request
    # to the completions endpoint with the requests module.
    model = "curie:ft-personal-2023-04-09-09-57-44"
    # Don't forget the " ->" separator at the end of the prompt.
    sentence = '酒店服务态度极差,设施很差,建议还是不要到那儿去。 ->'
    request_url = 'https://api.openai.com/v1/completions'
    headers = {'Content-Type': 'application/json',
               'Authorization': 'Bearer ' + open('openai_api_key').read()}
    data = json.dumps({"model": model, 'prompt': sentence, 'max_tokens': 6, 'temperature': 0})
    response = requests.post(request_url, headers=headers, data=data)
    response = json.loads(response.text)
    print(response['choices'][0]['text'])


def use_fine_tune_model_02():
    # Call the same model through the openai Python library.
    openai.api_key = open('openai_api_key').read()
    model = "curie:ft-personal-2023-04-09-09-57-44"
    sentence = '酒店服务态度极差,设施很差,建议还是不要到那儿去。 ->'
    outputs = openai.Completion.create(model=model, prompt=sentence, max_tokens=6, temperature=0)
    print(outputs['choices'][0]['text'])


if __name__ == '__main__':
    use_fine_tune_model_01()
    use_fine_tune_model_02()
The program outputs:
消极
消极
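To tie things together, here is a small hypothetical classify helper (not part of the original run) built on the same openai.Completion.create call. It appends the " ->" separator automatically and applies the stop=["极"] suggestion from the preparation tool's output; since both labels (消极 / 积极) end in 极, the helper re-attaches that character to the stripped output:

import openai

openai.api_key = open('openai_api_key').read()
MODEL = "curie:ft-personal-2023-04-09-09-57-44"


def classify(text):
    # Hypothetical helper: append the " ->" separator the fine-tuned model
    # expects, and stop at "极" as the preparation tool suggested. The stop
    # sequence is not included in the output, so we add 极 back to rebuild
    # the full label.
    response = openai.Completion.create(
        model=MODEL,
        prompt=text + ' ->',
        max_tokens=6,
        temperature=0,
        stop=['极'],
    )
    return response['choices'][0]['text'].strip() + '极'


if __name__ == '__main__':
    print(classify('房间的环境非常差,而且房间还不隔音,住的不舒服。'))  # expected: 消极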