Skip to content

设置断点续训

一、挂载Checkpoint挂载目录

在创建训练作业时,在“自定义挂载”中指定文件管理中模型文件存储源路径与训练环境的挂载路径,并在训练代码中指定checkpoint存储路径为挂载路径。系统在训练作业启动时,自动将数据存储位置中的Checkpoint文件挂载到训练容器的本地目录。

1.png

二、PyTorch设置断点续训

PyTorch模型保存有两种方式:

  • 仅保存模型参数
python
# Extract the learnable parameters and write only them to disk.
torch.save(model.state_dict(), path)
  • 保存整个Model(不推荐)
python
# Pickles the entire model object, which ties the saved file to the exact
# class/module layout at save time — the reason this form is discouraged.
torch.save(model, path)
  • 可根据step步数、时间等周期性保存模型的训练过程的产物。
    • 将模型训练过程中的网络权重、优化器权重、以及epoch进行保存,便于中断后继续训练恢复。
python

   # Bundle everything needed to resume training after an interruption:
   # network weights, optimizer state, and the epoch counter.
   checkpoint = {
           "net": model.state_dict(),
           "optimizer": optimizer.state_dict(),
           "epoch": epoch
   }
   # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
   os.makedirs('model_save_dir', exist_ok=True)
   torch.save(checkpoint,'model_save_dir/ckpt_{}.pth'.format(str(epoch)))
  • 完整代码示例。
python
import os
import argparse
import torch
from datetime import datetime

parser = argparse.ArgumentParser()
parser.add_argument("--train_url", type=str)
parser.add_argument("--epochs", type=int, default=100)
args, unparsed = parser.parse_known_args()
# The platform maps train_url to "/home/ma-user/modelarts/outputs/train_url_0".
train_url = args.train_url
start_epoch = 0
# Look for previous checkpoints in the output directory. With no checkpoint we
# train from scratch; otherwise load the .pth file with the largest epoch.
# Filtering for ".pth" BEFORE the emptiness test avoids an IndexError when the
# directory is non-empty but contains only non-checkpoint files (e.g. logs).
ckpt_files = [file for file in os.listdir(train_url) if file.endswith(".pth")]
if ckpt_files:
    print('> load last ckpt and continue training!!')
    # File names end in "_{epoch}.pth"; sort numerically by that epoch.
    last_ckpt = sorted(ckpt_files, key=lambda x: int(x.split('_')[-1].split('.')[0]))[-1]
    local_ckpt_file = os.path.join(train_url, last_ckpt)
    print('last_ckpt:', last_ckpt)
    # Load the checkpoint dict ("net" / "optimizer" / "epoch").
    checkpoint = torch.load(local_ckpt_file)
    # Restore the learnable parameters.
    model.load_state_dict(checkpoint['net'])
    # Restore optimizer state (momentum buffers, etc.).
    optimizer.load_state_dict(checkpoint['optimizer'])
    # The saved epoch has already finished, so resume at the next one.
    # (Adding 1 here, and looping from start_epoch below, also fixes the
    # original off-by-one that skipped epoch 0 on a fresh run.)
    start_epoch = checkpoint['epoch'] + 1
start = datetime.now()
total_step = len(train_loader)
for epoch in range(start_epoch, args.epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ...

    # Persist network weights, optimizer state and the finished epoch so an
    # interrupted job can resume from here.
    checkpoint = {
          "net": model.state_dict(),
          "optimizer": optimizer.state_dict(),
          "epoch": epoch
        }
    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(train_url, exist_ok=True)
    torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))

以下是TensorFlow实现断点续训的完整示例,参考PyTorch版本的结构设计,同时遵循TensorFlow的API特性:

三、TensorFlow容错训练示例

  1. Checkpoint管理
python
# Track the model's variables in a Checkpoint object; the CheckpointManager
# writes numbered "ckpt-N" files under train_url and prunes older ones
# beyond max_to_keep.
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(
    checkpoint, 
    directory=args.train_url, 
    max_to_keep=5,
    checkpoint_name='ckpt'
)
  • 使用Checkpoint对象跟踪模型状态
  • 使用CheckpointManager自动管理检查点版本
  2. 检查点保存
python
# Save a numbered checkpoint. NOTE(review): "n" is a placeholder for the save
# interval described in the surrounding text — replace with a concrete value.
save_path = manager.save(checkpoint_number=epoch+n)
  • 每n个epoch保存一次检查点,请酌情修改
  • 自动添加版本号
  • 自动清理旧检查点
  3. 恢复逻辑
python
# latest_checkpoint is the path of the newest checkpoint, or None if none exist.
latest_checkpoint = manager.latest_checkpoint
if latest_checkpoint:
    checkpoint.restore(latest_checkpoint)
    # Checkpoint paths end in "-<number>", so the epoch can be recovered
    # from the file name itself.
    start_epoch = int(latest_checkpoint.split('-')[-1])
  • 自动查找最新检查点
  • 从检查点路径提取epoch信息
  4. 完整代码示例
python
import os
import argparse
import tensorflow as tf
from datetime import datetime

# Command-line arguments for the checkpoint directory and epoch budget.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--train_url", type=str, default="./checkpoints", help="模型保存路径")
arg_parser.add_argument("--epochs", type=int, default=100, help="训练总轮次")
args = arg_parser.parse_args()

# Make sure the checkpoint directory exists before anything uses it.
os.makedirs(args.train_url, exist_ok=True)

# Model factory (MNIST classification is the running example).
def create_model():
    """Build and compile a small dense classifier for 28x28 MNIST images."""
    net = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10),
    ])
    net.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam',
        metrics=['accuracy'],
    )
    return net

# Instantiate the network.
model = create_model()

# The Checkpoint tracks the model's variables; the manager rotates the
# on-disk files, keeping only the 5 newest.
# NOTE(review): the optimizer created later inside train() is NOT tracked
# here, so its internal state is lost on restart — confirm acceptable.
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint,
                                     directory=args.train_url,
                                     checkpoint_name='ckpt',
                                     max_to_keep=5)

# Resume from the newest checkpoint when one exists, otherwise start fresh.
start_epoch = 0
latest_checkpoint = manager.latest_checkpoint

if latest_checkpoint is None:
    print("未找到检查点,从头开始训练")
else:
    checkpoint.restore(latest_checkpoint)
    # Checkpoint paths end in "-<number>"; that number is the epoch count.
    start_epoch = int(latest_checkpoint.rsplit('-', 1)[-1])
    print(f"已恢复检查点: {latest_checkpoint}, 从 epoch {start_epoch} 继续训练")

# MNIST data, pixel values scaled to [0, 1].
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# Batched tf.data pipelines for training and evaluation.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(10000).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)

# Custom training loop (demonstrates the low-level GradientTape API).
def train():
    """Run training from start_epoch to args.epochs, checkpointing each epoch.

    Uses the module-level model / manager / datasets; saves a numbered
    checkpoint after every epoch so an interrupted job can resume.
    """
    # Loss and optimizer for the custom loop.
    # NOTE(review): this optimizer is not part of the tf.train.Checkpoint
    # defined above, so its state is not restored on resume — confirm.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()

    # Running metrics, reset at the start of every epoch.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    @tf.function
    def train_step(images, labels):
        """One optimization step on a single training batch."""
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_loss(loss)
        train_accuracy(labels, predictions)

    @tf.function
    def test_step(images, labels):
        """Evaluate a single batch (no gradient updates)."""
        predictions = model(images, training=False)
        t_loss = loss_fn(labels, predictions)

        test_loss(t_loss)
        test_accuracy(labels, predictions)

    print(f"训练开始: 总轮次={args.epochs}, 从 epoch {start_epoch+1} 开始")
    for epoch in range(start_epoch, args.epochs):
        # Reset metrics so each epoch reports its own numbers.
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        # One pass over the training data.
        for images, labels in train_dataset:
            train_step(images, labels)

        # Evaluate. Loop variables renamed from test_images/test_labels,
        # which shadowed the module-level arrays of the same names.
        for eval_images, eval_labels in test_dataset:
            test_step(eval_images, eval_labels)

        print(f'Epoch {epoch+1}, '
              f'Loss: {train_loss.result()}, '
              f'Accuracy: {train_accuracy.result() * 100}, '
              f'Test Loss: {test_loss.result()}, '
              f'Test Accuracy: {test_accuracy.result() * 100}')

        # Persist progress; checkpoint_number=epoch+1 equals epochs completed,
        # matching the resume logic that parses the number from the path.
        save_path = manager.save(checkpoint_number=epoch+1)
        print(f"Epoch {epoch+1} 检查点已保存: {save_path}")

# Run the training loop.
train()