Setting Up Resumable Training (Checkpointing)
When creating a training job, use "Custom Mount" to map the model-file source path in file management to a mount path inside the training environment, and point the checkpoint save path in your training code at that mount path. When the training job starts, the system automatically mounts the checkpoint files from the data storage location into a local directory of the training container.
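For example, if the mount path configured for the job is "/home/ma-user/modelarts/outputs/train_url_0" (the path used in the examples below), the training code simply treats it as an ordinary local directory. A minimal sketch, assuming that mount path:

import os

# Hypothetical mount path; in practice it comes from the job's --train_url argument.
ckpt_dir = "/home/ma-user/modelarts/outputs/train_url_0"
os.makedirs(ckpt_dir, exist_ok=True)  # make sure it exists on the first run
print("Checkpoints are read from and written to:", ckpt_dir)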
There are two ways to save a PyTorch model:
# Option 1: save only the learnable parameters (the state_dict)
state_dict = model.state_dict()
torch.save(state_dict, path)
# Option 2: save the entire model object
torch.save(model, path)
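Saving only the weights is not enough to resume an interrupted run. For resumable training, save the model weights, the optimizer state, and the current epoch together in one checkpoint: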
checkpoint = {
    "net": model.state_dict(),             # model weights
    "optimizer": optimizer.state_dict(),   # optimizer state
    "epoch": epoch                         # current epoch
}
if not os.path.isdir('model_save_dir'):
    os.makedirs('model_save_dir')
torch.save(checkpoint, 'model_save_dir/ckpt_{}.pth'.format(epoch))
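For reference, each save style has a matching load call. A minimal sketch, with path and the model object assumed to exist:

# Option 1: the file holds a state_dict; rebuild the model, then load its weights
model.load_state_dict(torch.load(path))
# Option 2: the file holds the whole pickled model object
model = torch.load(path)

The complete PyTorch example below combines the checkpoint dict with the logic for finding and restoring the latest checkpoint: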
import os
import argparse
import torch
from datetime import datetime

parser = argparse.ArgumentParser()
parser.add_argument("--train_url", type=str)
parser.add_argument("--epochs", type=int, default=100)
args, unparsed = parser.parse_known_args()
# train_url will be set to "/home/ma-user/modelarts/outputs/train_url_0"
train_url = args.train_url

# model, optimizer, criterion, and train_loader are assumed to be defined above.
start_epoch = 0
# Check whether the output path already contains checkpoint files. If not, train
# from scratch; otherwise load the .pth file with the largest epoch number as the
# pretrained model.
ckpt_files = [file for file in os.listdir(train_url) if file.endswith(".pth")] \
    if os.path.isdir(train_url) else []
if ckpt_files:
    print('> load last ckpt and continue training!!')
    last_ckpt = sorted(ckpt_files, key=lambda x: int(x.split('_')[-1].split('.')[0]))[-1]
    local_ckpt_file = os.path.join(train_url, last_ckpt)
    print('last_ckpt:', last_ckpt)
    # Load the checkpoint
    checkpoint = torch.load(local_ckpt_file)
    # Restore the model's learnable parameters
    model.load_state_dict(checkpoint['net'])
    # Restore the optimizer state
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Restore the saved epoch; training continues from the epoch after it
    start_epoch = checkpoint['epoch']
start = datetime.now()
total_step = len(train_loader)
for epoch in range(start_epoch + 1, args.epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ...
    # Save the network weights, optimizer state, and epoch for this round
    checkpoint = {
        "net": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir(train_url):
        os.makedirs(train_url)
    torch.save(checkpoint, os.path.join(train_url, 'ckpt_best_{}.pth'.format(epoch)))
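If a checkpoint saved on GPU may be resumed on different hardware (for example a CPU debugging session), torch.load accepts a map_location argument. A minimal sketch:

# Remap saved tensors to whatever device is currently available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(local_ckpt_file, map_location=device)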
The TensorFlow version of resumable training follows the same structure as the PyTorch example while using TensorFlow's own checkpoint APIs. The key pieces are tf.train.Checkpoint and tf.train.CheckpointManager:
# Track the objects to save and create a manager that rotates old checkpoints
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=args.train_url,
    max_to_keep=5,
    checkpoint_name='ckpt'
)
# Save one checkpoint per epoch, numbered so the epoch can be recovered later
save_path = manager.save(checkpoint_number=epoch + 1)
# On restart, restore the latest checkpoint and parse the epoch from its name
latest_checkpoint = manager.latest_checkpoint
if latest_checkpoint:
    checkpoint.restore(latest_checkpoint)
    start_epoch = int(latest_checkpoint.split('-')[-1])
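The complete example below puts these pieces together, using MNIST classification and a custom training loop: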
import os
import argparse
import tensorflow as tf
from datetime import datetime

# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--train_url", type=str, default="./checkpoints", help="checkpoint save directory")
parser.add_argument("--epochs", type=int, default=100, help="total number of training epochs")
args = parser.parse_args()

# Create the save directory
os.makedirs(args.train_url, exist_ok=True)

# Build the model (MNIST classification as an example)
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    return model

# Create the model (weights are restored below if a checkpoint exists)
model = create_model()

# Define the checkpoint and its manager.
# Note: only the model is tracked here; see the note after the example
# for also tracking the optimizer state.
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=args.train_url,
    max_to_keep=5,          # keep only the 5 most recent checkpoints
    checkpoint_name='ckpt'
)

# Try to restore the latest checkpoint
latest_checkpoint = manager.latest_checkpoint
start_epoch = 0
if latest_checkpoint:
    checkpoint.restore(latest_checkpoint)
    # Parse the epoch from the checkpoint path (format: ckpt-{epoch})
    start_epoch = int(latest_checkpoint.split('-')[-1])
    print(f"Restored checkpoint: {latest_checkpoint}, resuming from epoch {start_epoch}")
else:
    print("No checkpoint found, training from scratch")

# Load the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0

# Create data pipelines
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(10000).batch(32)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(32)
# Custom training loop (demonstrates the low-level API)
def train():
    # Define the loss function and optimizer
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.Adam()

    # Define metrics
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

    # Define one training step
    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(loss)
        train_accuracy(labels, predictions)

    # Define one test step
    @tf.function
    def test_step(images, labels):
        predictions = model(images, training=False)
        t_loss = loss_fn(labels, predictions)
        test_loss(t_loss)
        test_accuracy(labels, predictions)

    # Start training
    print(f"Training starts: total epochs={args.epochs}, starting from epoch {start_epoch + 1}")
    for epoch in range(start_epoch, args.epochs):
        # Reset metrics
        train_loss.reset_states()
        train_accuracy.reset_states()
        test_loss.reset_states()
        test_accuracy.reset_states()

        # Train for one epoch
        for images, labels in train_dataset:
            train_step(images, labels)

        # Evaluate
        for images, labels in test_dataset:
            test_step(images, labels)

        # Print progress
        print(f'Epoch {epoch + 1}, '
              f'Loss: {train_loss.result()}, '
              f'Accuracy: {train_accuracy.result() * 100}, '
              f'Test Loss: {test_loss.result()}, '
              f'Test Accuracy: {test_accuracy.result() * 100}')

        # Save a checkpoint
        save_path = manager.save(checkpoint_number=epoch + 1)
        print(f"Epoch {epoch + 1} checkpoint saved: {save_path}")

# Run training
train()
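Note that the tf.train.Checkpoint above tracks only the model, so the Adam optimizer's internal state (its moment estimates) starts fresh after every resume. tf.train.Checkpoint can track the optimizer as well; a minimal sketch, assuming the optimizer is created before the checkpoint object so it can be passed in:

# Track both model and optimizer so a restore also brings back the
# optimizer's internal state (assumes optimizer exists at this point).
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)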