Shortcuts

Trainer

class pybind11_ke.config.Trainer(model: Strategy | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', accelerator: accelerate.Accelerator | None = None, use_gpu: bool = True, device: str = 'cuda:0', tester: Tester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hits@10', patience: int = 2, delta: float = 0, use_wandb: bool = False)[源代码]

主要用于 KGE 模型的训练。

例子:

from pybind11_ke.data import KGEDataLoader, BernSampler, TradTestSampler
from pybind11_ke.module.model import TransE
from pybind11_ke.module.loss import MarginLoss
from pybind11_ke.module.strategy import NegativeSampling
from pybind11_ke.config import Trainer, Tester

# dataloader for training
dataloader = KGEDataLoader(
        in_path = "../../benchmarks/FB15K/", 
        batch_size = 8192,
        neg_ent = 25,
        test = True,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = BernSampler,
        test_sampler = TradTestSampler
)

# define the model
transe = TransE(
        ent_tol = dataloader.get_ent_tol(),
        rel_tol = dataloader.get_rel_tol(),
        dim = 50, 
        p_norm = 1, 
        norm_flag = True)

# define the loss function
model = NegativeSampling(
        model = transe, 
        loss = MarginLoss(margin = 1.0),
        regul_rate = 0.01
)

# test the model
tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

# train the model
trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
        epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
        tester = tester, test = True, valid_interval = 100,
        log_interval = 100, save_interval = 100,
        save_path = '../../checkpoint/transe.pth', delta = 0.01)
trainer.run()
__init__(model: Strategy | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', accelerator: accelerate.Accelerator | None = None, use_gpu: bool = True, device: str = 'cuda:0', tester: Tester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hits@10', patience: int = 2, delta: float = 0, use_wandb: bool = False)[源代码]

创建 Trainer 对象。

参数:
__weakref__

list of weak references to the object (if defined)

accelerator

是否进行分布式并行训练,pybind11_ke.config.accelerator_prepare() 返回列表中的最后一个元素。

configure_optimizers()[源代码]

可以通过重新实现该方法自定义配置优化器。

data_loader: torch.utils.data.DataLoader

__init__() 传入的 torch.utils.data.DataLoader

delta: float

pybind11_ke.utils.EarlyStopping.delta 参数,监测数量的最小变化才符合改进条件。默认值:0

device: Union[torch.device, str]

gpu,利用 device 构造的 torch.device 对象

early_stopping: EarlyStopping

早停对象

epochs: int

epochs

get_device() torch.device | str[源代码]

返回当前进程的设备。

返回:

设备信息

返回类型:

Union[torch.device, str]

get_model() Model[源代码]

返回原始的 KGE 模型。

返回:

KGE 模型

返回类型:

pybind11_ke.module.model.Model

is_local_main_process() bool[源代码]

当前进程是否是主进程。

返回:

当前进程是否是主进程。

返回类型:

bool

log_interval: int | None

训练几轮输出一次日志

lr: float

学习率

metric: str

早停使用的验证指标,可选值:’mr’, ‘mrr’, ‘hits@N’, ‘mr_type’, ‘mrr_type’, ‘hits@N_type’。默认值:’hits@10’

model: Strategy

包装 KGE 模型的训练策略类,即 pybind11_ke.module.strategy.Strategy

opt_method: str

用户传入的优化器名字字符串

optimizer: torch.optim.SGD | torch.optim.Adagrad | torch.optim.Adam | None

根据 __init__()opt_method 生成对应的优化器

patience: int

pybind11_ke.utils.EarlyStopping.patience 参数,上次验证得分改善后等待多长时间。默认值:2

print_test(sampling_mode: str, epoch: int = 0)[源代码]

根据 tester 类型进行链接预测 。

参数:

sampling_mode (str) – 数据

run()[源代码]

训练循环,首先根据 use_gpu 设置 model 是否使用 gpu 训练,然后根据 opt_method 设置 optimizer,最后迭代 data_loader 获取数据, 并利用 train_one_step() 训练。

save_interval: int | None

训练几轮保存一次模型

save_path: str | None

模型保存的路径

scheduler: torch.optim.lr_scheduler.MultiStepLR | None

学习率调度器

test: bool

是否在测试集上评估模型, tester 不为空

tester: Tester | None

用于模型评估的验证模型类

to_var(x: torch.Tensor) torch.Tensor[源代码]

x 转移到对应的设备上。

参数:

x (torch.Tensor) – 数据

返回:

张量

返回类型:

torch.Tensor

train_one_step(data: dict[str, Union[str, dgl.DGLGraph, torch.Tensor]]) float[源代码]

根据 data_loader 生成的 1 批次(batch) data 将 模型训练 1 步。

参数:

data (dict[str, Union[dgl.DGLGraph, torch.Tensor]]) – 训练数据

返回:

损失值

返回类型:

float

use_early_stopping: bool

是否启用早停,需要 testersave_path 不为空

use_gpu: bool

是否使用 gpu

use_wandb: bool

是否启用 wandb 进行日志输出

valid_interval: int | None

训练几轮在验证集上评估一次模型, tester 不为空

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs