Shortcuts

trainer_distributed_data_parallel

pybind11_ke.config.trainer_distributed_data_parallel(model=None, data_loader: TrainDataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hit10', patience: int = 2, delta: float = 0, valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', type_constrain: bool = True, use_wandb: bool = False)[源代码]

并行训练循环函数,用于生成单独子进程进行训练模型。

torch.multiprocessing 是 Python 原生 multiprocessing 的一个 PyTorch 的包装器。 multiprocessing 的生成进程函数必须由 if __name__ == '__main__' 保护。 有效的 batch size 是 pybind11_ke.data.TrainDataLoader.batch_size * nprocs

例子:

from pybind11_ke.config import trainer_distributed_data_parallel

if __name__ == "__main__":

        trainer_distributed_data_parallel(model = model, data_loader = train_dataloader,
                epochs = 1000, lr = 0.02, opt_method = "adam",
                test = True, valid_interval = 100, log_interval = 100, save_interval = 100,
                save_path = "../../checkpoint/transe.pth", type_constrain = False)
参数:
  • model (pybind11_ke.module.strategy.NegativeSampling) – 包装 KGE 模型的训练策略类

  • data_loader (pybind11_ke.data.TrainDataLoader) – TrainDataLoader

  • epochs (int) – 训练轮次数

  • lr (float) – 学习率

  • opt_method (str) – 优化器: Adam or adam, SGD or sgd

  • test (bool) – 是否在测试集上评估模型

  • valid_interval (int) – 训练几轮在验证集上评估一次模型

  • log_interval (int) – 训练几轮输出一次日志

  • save_interval (int) – 训练几轮保存一次模型

  • save_path (str) – 模型保存的路径

  • use_early_stopping (bool) – 是否启用早停,需要 testersave_path 不为空

  • metric (str) – 早停使用的验证指标,可选值:’mrr’, ‘hit1’, ‘hit3’, ‘hit10’, ‘mrTC’, ‘mrrTC’, ‘hit1TC’, ‘hit3TC’, ‘hit10TC’。 ‘mrTC’, ‘mrrTC’, ‘hit1TC’, ‘hit3TC’, ‘hit10TC’ 需要 pybind11_ke.data.TestDataLoader.type_constrain 为 True。默认值:’hit10’

  • patience (int) – pybind11_ke.utils.EarlyStopping.patience 参数,上次验证得分改善后等待多长时间。默认值:2

  • delta (float) – pybind11_ke.utils.EarlyStopping.delta 参数,监测数量的最小变化才符合改进条件。默认值:0

  • valid_file (str) – valid2id.txt

  • test_file (str) – test2id.txt

  • type_constrain (bool) – 是否用 type_constrain.txt 进行负采样

  • use_wandb (bool) – 是否启用 wandb 进行日志输出

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs