Shortcuts

pybind11_ke.config.TrainerDataParallel 源代码

# coding:utf-8
#
# pybind11_ke/config/TrainerDataParallel.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on July 5, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 6, 2023
#
# 该脚本定义了并行训练循环函数.

"""
trainer_distributed_data_parallel - 并行训练循环函数。
"""

import os
import torch
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from .Tester import Tester
from .Trainer import Trainer
from ..data import TrainDataLoader, TestDataLoader
from ..module.strategy import NegativeSampling

def ddp_setup(rank: int, world_size: int):
	"""
	Initialise the process group for the calling worker process.

	The rendezvous endpoint is fixed to ``localhost:12355`` and the ``gloo``
	backend is requested; on Windows, :py:mod:`torch.distributed` only
	supports the ``Gloo`` backend. Once the group exists, the current CUDA
	device is set to ``rank``.

	:param rank: unique identifier of the process
	:type rank: int
	:param world_size: total number of processes
	:type world_size: int
	"""

	# Every worker must agree on where the master process listens.
	rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
	os.environ.update(rendezvous)

	init_process_group(world_size=world_size, rank=rank, backend="gloo")

	# One process per GPU: the rank doubles as the CUDA device index.
	torch.cuda.set_device(rank)

def train(
	rank: int,
	world_size: int,
	model: NegativeSampling,
	data_loader: TrainDataLoader,
	epochs: int,
	lr: float,
	opt_method: str,
	test: bool,
	valid_interval: int | None,
	log_interval: int | None,
	save_interval: int | None,
	save_path: str | None,
	use_early_stopping: bool,
	metric: str,
	patience: int,
	delta: float,
	valid_file: str,
	test_file: str,
	type_constrain: bool,
	use_wandb: bool):
	"""Worker entry point executed by every spawned process.

	Sets up the process group, optionally builds a tester from the training
	loader's data files, runs a :py:class:`Trainer` bound to GPU ``rank`` and
	always tears the process group down afterwards.

	:param rank: index of the process; also used as the GPU index
	:type rank: int
	:param world_size: total number of processes
	:type world_size: int
	:param model: training strategy class wrapping the KGE model
	:type model: :py:class:`pybind11_ke.module.strategy.NegativeSampling`
	:param data_loader: TrainDataLoader
	:type data_loader: :py:class:`pybind11_ke.data.TrainDataLoader`
	:param epochs: number of training epochs
	:type epochs: int
	:param lr: learning rate
	:type lr: float
	:param opt_method: optimizer: ``Adam`` or ``adam``, ``SGD`` or ``sgd``
	:type opt_method: str
	:param test: whether to evaluate the model on the test set
	:type test: bool
	:param valid_interval: evaluate on the validation set every this many epochs
	:type valid_interval: int
	:param log_interval: emit a log record every this many epochs
	:type log_interval: int
	:param save_interval: save the model every this many epochs
	:type save_interval: int
	:param save_path: path the model is saved to
	:type save_path: str
	:param use_early_stopping: whether to enable early stopping
	:type use_early_stopping: bool
	:param metric: validation metric used for early stopping, one of 'mrr', 'hit1', 'hit3', 'hit10',
		'mrTC', 'mrrTC', 'hit1TC', 'hit3TC', 'hit10TC'. The ``*TC`` metrics require
		:py:attr:`pybind11_ke.data.TestDataLoader.type_constrain` to be True.
	:type metric: str
	:param patience: :py:attr:`pybind11_ke.utils.EarlyStopping.patience`, how long to wait after the
		last improvement of the validation score.
	:type patience: int
	:param delta: :py:attr:`pybind11_ke.utils.EarlyStopping.delta`, minimum change of the monitored
		quantity that counts as an improvement.
	:type delta: float
	:param valid_file: valid2id.txt
	:type valid_file: str
	:param test_file: test2id.txt
	:type test_file: str
	:param type_constrain: whether to use type_constrain.txt for negative sampling
	:type type_constrain: bool
	:param use_wandb: whether to log through wandb
	:type use_wandb: bool
	"""

	ddp_setup(rank, world_size)

	if test:
		# The evaluation loader reuses the dataset location of the training loader.
		eval_loader = TestDataLoader(
			in_path=data_loader.in_path,
			ent_file=data_loader.ent_file,
			rel_file=data_loader.rel_file,
			train_file=data_loader.train_file,
			valid_file=valid_file,
			test_file=test_file,
			type_constrain=type_constrain,
		)
		tester = Tester(model=model.model, data_loader=eval_loader)
	else:
		tester = None

	trainer_options = dict(
		model=model,
		data_loader=data_loader,
		epochs=epochs,
		lr=lr,
		opt_method=opt_method,
		tester=tester,
		test=test,
		valid_interval=valid_interval,
		log_interval=log_interval,
		save_interval=save_interval,
		save_path=save_path,
		use_early_stopping=use_early_stopping,
		metric=metric,
		patience=patience,
		delta=delta,
		use_wandb=use_wandb,
		gpu_id=rank,
	)
	worker = Trainer(**trainer_options)

	try:
		worker.run()
	except RuntimeError as err:
		# NOTE(review): every RuntimeError raised by the run is reported as an
		# early-stop notice, whatever its actual cause -- confirm this is intended.
		print("RuntimeError:", err)
		print(f"[GPU0] Due to the lack of improvement in validation scores, an early stop has been made, and [GPU{rank}] also needs to stop model training.")
	finally:
		# Release the process group whether the run finished or failed.
		destroy_process_group()
	
def trainer_distributed_data_parallel(
	model = None,
	data_loader: TrainDataLoader | None = None,
	epochs: int = 1000,
	lr: float = 0.5,
	opt_method: str = "Adam",
	test: bool = False,
	valid_interval: int | None = None,
	log_interval: int | None = None,
	save_interval: int | None = None,
	save_path: str | None = None,
	use_early_stopping: bool = True,
	metric: str = 'hit10',
	patience: int = 2,
	delta: float = 0,
	valid_file: str = "valid2id.txt",
	test_file: str = "test2id.txt",
	type_constrain: bool = True,
	use_wandb: bool = False):
	"""Parallel training loop: spawn one training sub-process per visible GPU.

	:py:mod:`torch.multiprocessing` is a ``PyTorch`` wrapper around Python's
	native ``multiprocessing``. Functions that spawn processes through
	``multiprocessing`` must be protected by ``if __name__ == '__main__'``.

	The effective batch size is
	:py:attr:`pybind11_ke.data.TrainDataLoader.batch_size` * ``nprocs``.

	Example::

		from pybind11_ke.config import trainer_distributed_data_parallel

		if __name__ == "__main__":

			trainer_distributed_data_parallel(model = model, data_loader = train_dataloader,
				epochs = 1000, lr = 0.02, opt_method = "adam",
				test = True, valid_interval = 100, log_interval = 100, save_interval = 100,
				save_path = "../../checkpoint/transe.pth", type_constrain = False)

	:param model: training strategy class wrapping the KGE model
	:type model: :py:class:`pybind11_ke.module.strategy.NegativeSampling`
	:param data_loader: TrainDataLoader
	:type data_loader: :py:class:`pybind11_ke.data.TrainDataLoader`
	:param epochs: number of training epochs
	:type epochs: int
	:param lr: learning rate
	:type lr: float
	:param opt_method: optimizer: ``Adam`` or ``adam``, ``SGD`` or ``sgd``
	:type opt_method: str
	:param test: whether to evaluate the model on the test set
	:type test: bool
	:param valid_interval: evaluate on the validation set every this many epochs
	:type valid_interval: int
	:param log_interval: emit a log record every this many epochs
	:type log_interval: int
	:param save_interval: save the model every this many epochs
	:type save_interval: int
	:param save_path: path the model is saved to
	:type save_path: str
	:param use_early_stopping: whether to enable early stopping; requires :py:attr:`tester`
		and :py:attr:`save_path` to be set
	:type use_early_stopping: bool
	:param metric: validation metric used for early stopping, one of 'mrr', 'hit1', 'hit3', 'hit10',
		'mrTC', 'mrrTC', 'hit1TC', 'hit3TC', 'hit10TC'. The ``*TC`` metrics require
		:py:attr:`pybind11_ke.data.TestDataLoader.type_constrain` to be True. Default: 'hit10'
	:type metric: str
	:param patience: :py:attr:`pybind11_ke.utils.EarlyStopping.patience`, how long to wait after the
		last improvement of the validation score. Default: 2
	:type patience: int
	:param delta: :py:attr:`pybind11_ke.utils.EarlyStopping.delta`, minimum change of the monitored
		quantity that counts as an improvement. Default: 0
	:type delta: float
	:param valid_file: valid2id.txt
	:type valid_file: str
	:param test_file: test2id.txt
	:type test_file: str
	:param type_constrain: whether to use type_constrain.txt for negative sampling
	:type type_constrain: bool
	:param use_wandb: whether to log through wandb
	:type use_wandb: bool
	:raises RuntimeError: if no CUDA device is visible, so no worker could be spawned
	"""

	# One worker process per visible CUDA device.
	world_size = torch.cuda.device_count()
	if world_size < 1:
		# With zero devices there would be no process to spawn and the call
		# would come back without having trained anything; fail loudly instead.
		raise RuntimeError(
			"trainer_distributed_data_parallel requires at least one CUDA device, "
			"but torch.cuda.device_count() returned 0."
		)

	# mp.spawn passes the process index as the first argument of `train`;
	# the tuple below supplies the remaining positional arguments in order.
	worker_args = (
		world_size, model, data_loader, epochs, lr, opt_method, test,
		valid_interval, log_interval, save_interval, save_path,
		use_early_stopping, metric, patience, delta,
		valid_file, test_file, type_constrain, use_wandb,
	)
	mp.spawn(train, args = worker_args, nprocs = world_size)

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs