trainer_distributed_data_parallel¶
- pybind11_ke.config.trainer_distributed_data_parallel(model: NegativeSampling | None = None, train_dataloader: torch.utils.data.DataLoader | None = None, val_dataloader: torch.utils.data.DataLoader | None = None, test_dataloader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', prediction: str = 'all', use_tqdm: bool = True, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hits@10', patience: int = 2, delta: float = 0, use_wandb: bool = False)[源代码]¶
并行训练循环函数,用于生成单独子进程进行训练模型。
torch.multiprocessing是 Python 原生multiprocessing的一个PyTorch的包装器。multiprocessing的生成进程函数必须由if __name__ == '__main__'保护。 有效的 batch size 是pybind11_ke.data.TrainDataLoader.batch_size*nprocs。例子:
from pybind11_ke.config import trainer_distributed_data_parallel if __name__ == "__main__": trainer_distributed_data_parallel(model = model, data_loader = train_dataloader, epochs = 1000, lr = 0.02, opt_method = "adam", test = True, valid_interval = 100, log_interval = 100, save_interval = 100, save_path = "../../checkpoint/transe.pth", type_constrain = False)
- 参数:
model (
pybind11_ke.module.strategy.NegativeSampling) – 包装 KGE 模型的训练策略类train_dataloader (
pybind11_ke.data.KGEDataLoader) – KGEDataLoaderepochs (int) – 训练轮次数
lr (float) – 学习率
opt_method (str) – 优化器:
Adamoradam,SGDorsgdprediction (str) – 链接预测模式: ‘all’、’head’、’tail’
use_tqdm (bool) – 是否启用进度条
valid_interval (int) – 训练几轮在验证集上评估一次模型
log_interval (int) – 训练几轮输出一次日志
save_interval (int) – 训练几轮保存一次模型
save_path (str) – 模型保存的路径
use_early_stopping (bool) – 是否启用早停,需要
tester和save_path不为空metric (str) – 早停使用的验证指标,可选值:’mr’, ‘mrr’, ‘hits@N’。默认值:’hits@10’
patience (int) –
pybind11_ke.utils.EarlyStopping.patience参数,上次验证得分改善后等待多长时间。默认值:2delta (float) –
pybind11_ke.utils.EarlyStopping.delta参数,监测数量的最小变化才符合改进条件。默认值:0use_wandb (bool) – 是否启用 wandb 进行日志输出