使用 torchrun 运行 ChatGLM2-6B 微调预测任务

本文档提供使用 torchrun 命令运行 ChatGLM2-6B 模型微调的预测任务的示例,包括 .sh 和 .bat 文件两种格式。

环境配置:

  • 已安装 torchrun
  • 已安装 Python 3.x
  • 已安装 PyTorch
  • 已安装 transformers 库
  • 已安装必要的依赖库

数据集:

  • AdvertiseGen/dev.json: 训练和测试数据集

模型:

  • THUDM/chatglm2-6b: 预训练模型

参数:

  • PRE_SEQ_LEN: 序列长度,默认值为 128
  • CHECKPOINT: 微调 checkpoint 文件夹名称,默认值为 adgen-chatglm2-6b-pt-128-2e-2
  • STEP: 微调步骤,默认值为 3000
  • NUM_GPUS: 使用的 GPU 数量,默认值为 1

运行指令:

.sh 文件

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
STEP=3000
NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \  
    --do_predict \  
    --validation_file AdvertiseGen/dev.json \  
    --test_file AdvertiseGen/dev.json \  
    --overwrite_cache \  
    --prompt_column content \  
    --response_column summary \  
    --model_name_or_path THUDM/chatglm2-6b \  
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \  
    --output_dir ./output/$CHECKPOINT \  
    --overwrite_output_dir \  
    --max_source_length 64 \  
    --max_target_length 64 \  
    --per_device_eval_batch_size 1 \  
    --predict_with_generate \  
    --pre_seq_len $PRE_SEQ_LEN \  
    --quantization_bit 4

.bat 文件

@echo off
set PRE_SEQ_LEN=128
set CHECKPOINT=adgen-chatglm2-6b-pt-128-2e-2
set STEP=3000
set NUM_GPUS=1

torchrun --standalone --nnodes=1 --nproc-per-node=%NUM_GPUS% main.py ^
    --do_predict ^
    --validation_file AdvertiseGen/dev.json ^
    --test_file AdvertiseGen/dev.json ^
    --overwrite_cache ^
    --prompt_column content ^
    --response_column summary ^
    --model_name_or_path THUDM/chatglm2-6b ^
    --ptuning_checkpoint ./output/%CHECKPOINT%/checkpoint-%STEP% ^
    --output_dir ./output/%CHECKPOINT% ^
    --overwrite_output_dir ^
    --max_source_length 64 ^
    --max_target_length 64 ^
    --per_device_eval_batch_size 1 ^
    --predict_with_generate ^
    --pre_seq_len %PRE_SEQ_LEN% ^
    --quantization_bit 4

解释:

  • torchrun 命令用于分布式训练,--standalone 参数表示使用单节点训练,--nnodes=1 表示使用一个节点,--nproc-per-node 参数表示每个节点使用的 GPU 数量。
  • --do_predict 参数表示执行预测任务。
  • --validation_file--test_file 参数指定验证集和测试集文件路径。
  • --overwrite_cache 参数表示覆盖缓存文件。
  • --prompt_column--response_column 参数指定输入和输出列名。
  • --model_name_or_path 参数指定预训练模型路径。
  • --ptuning_checkpoint 参数指定微调 checkpoint 路径。
  • --output_dir 参数指定输出文件夹路径。
  • --overwrite_output_dir 参数表示覆盖输出文件夹。
  • --max_source_length--max_target_length 参数指定输入和输出序列最大长度。
  • --per_device_eval_batch_size 参数指定每个 GPU 的评估批次大小。
  • --predict_with_generate 参数表示使用生成方式进行预测。
  • --pre_seq_len 参数指定预处理序列长度。
  • --quantization_bit 参数指定量化位数。

注意:

  • 以上代码中的路径需要根据实际情况进行调整。
  • 可以根据需要修改参数值。
  • 在运行之前,请确保已安装必要的依赖库。
ChatGLM2-6B 模型微调预测任务 - 使用 torchrun 运行 .sh 和 .bat 文件

原文地址: https://www.cveoy.top/t/topic/qjvk 著作权归作者所有。请勿转载和采集!

免费AI点我,无需注册和登录