GraphTrainer

class pybind11_ke.config.GraphTrainer(model: RGCNSampling | CompGCNSampling | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', use_gpu: bool = True, device: str = 'cuda:0', tester: GraphTester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hit10', patience: int = 2, delta: float = 0, use_wandb: bool = False, gpu_id: int | None = None)

Mainly used for training R-GCN [SKB+18] and CompGCN [VSNT20].

Example:

from pybind11_ke.data import CompGCNSampler, CompGCNTestSampler, GraphDataLoader
from pybind11_ke.module.model import CompGCN
from pybind11_ke.module.loss import Cross_Entropy_Loss
from pybind11_ke.module.strategy import CompGCNSampling
from pybind11_ke.config import GraphTrainer, GraphTester

dataloader = GraphDataLoader(
        in_path = "../../benchmarks/FB15K237/",
        batch_size = 2048,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = CompGCNSampler,
        test_sampler = CompGCNTestSampler
)

# define the model
compgcn = CompGCN(
        ent_tol = dataloader.train_sampler.ent_tol,
        rel_tol = dataloader.train_sampler.rel_tol,
        dim = 100
)

# wrap the model in a training strategy and attach the loss function
model = CompGCNSampling(
        model = compgcn,
        loss = Cross_Entropy_Loss(model = compgcn),
        ent_tol = dataloader.train_sampler.ent_tol
)

# define the tester used for validation and testing
tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

# train the model
trainer = GraphTrainer(model = model, data_loader = dataloader.train_dataloader(),
        epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
        tester = tester, test = True, valid_interval = 50, log_interval = 50,
        save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
)
trainer.run()
__init__(model: RGCNSampling | CompGCNSampling | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', use_gpu: bool = True, device: str = 'cuda:0', tester: GraphTester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hit10', patience: int = 2, delta: float = 0, use_wandb: bool = False, gpu_id: int | None = None)

Create a GraphTrainer object.

Parameters: the constructor arguments correspond to the attributes of the same name documented below.
__weakref__

list of weak references to the object (if defined)

configure_optimizers()

Override this method to customize how the optimizer is configured.

data_loader: Union[TrainDataLoader, torch.utils.data.DataLoader, None]

The pybind11_ke.data.TrainDataLoader or torch.utils.data.DataLoader passed to __init__().

delta: float

The pybind11_ke.utils.EarlyStopping.delta parameter: the minimum change in the monitored metric that counts as an improvement. Default: 0

device: Union[torch.device, str]

The GPU device: the torch.device object constructed from the device argument.

early_stopping: EarlyStopping

The early-stopping object.

epochs: int

The number of training epochs.

get_model() → Model

Return the original KGE model.

gpu_id: int | None

The index of the GPU to use.

log_interval: int | None

The number of training epochs between log outputs.

lr: float

The learning rate.

metric: str

The validation metric used for early stopping. Allowed values: 'mrr', 'hit1', 'hit3', 'hit10', 'mrTC', 'mrrTC', 'hit1TC', 'hit3TC', 'hit10TC'. The values 'mrTC', 'mrrTC', 'hit1TC', 'hit3TC', and 'hit10TC' require pybind11_ke.data.TestDataLoader.type_constrain to be True. Default: 'hit10'
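
The constraint above can be checked before training starts. The following is a minimal sketch and not part of Pybind11-OpenKE: check_metric is a hypothetical helper, and its type_constrain argument stands in for the value of pybind11_ke.data.TestDataLoader.type_constrain.

```python
# Hypothetical helper, not part of pybind11_ke: validates a `metric` value
# against the names listed in the documentation above.
PLAIN_METRICS = {"mrr", "hit1", "hit3", "hit10"}
TYPE_CONSTRAINED_METRICS = {"mrTC", "mrrTC", "hit1TC", "hit3TC", "hit10TC"}


def check_metric(metric: str, type_constrain: bool) -> str:
    """Return `metric` unchanged if it is usable, otherwise raise ValueError."""
    if metric in PLAIN_METRICS:
        return metric
    if metric in TYPE_CONSTRAINED_METRICS:
        # The *TC metrics are only available with type constraints enabled.
        if not type_constrain:
            raise ValueError(f"metric {metric!r} requires type_constrain = True")
        return metric
    raise ValueError(f"unknown metric {metric!r}")
```

Calling such a check before constructing the trainer would surface a mismatched metric name early instead of at the first validation pass.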

model: torch.nn.parallel.DistributedDataParallel | NegativeSampling | RGCNSampling | CompGCNSampling | None

The training strategy class that wraps the KGE model, i.e. pybind11_ke.module.strategy.NegativeSampling, pybind11_ke.module.strategy.RGCNSampling, or pybind11_ke.module.strategy.CompGCNSampling.

opt_method: str

The optimizer name string passed in by the user.

optimizer: torch.optim.SGD | torch.optim.Adagrad | torch.optim.Adam | None

The optimizer created according to the opt_method argument of __init__().

patience: int

The pybind11_ke.utils.EarlyStopping.patience parameter: how long to wait after the last improvement in the validation score. Default: 2
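
How patience and delta interact can be illustrated with a small stand-alone sketch. This is not the code of pybind11_ke.utils.EarlyStopping: the class name EarlyStoppingSketch, its step method, and the assumption that a higher score is better (as with 'hit10' or 'mrr') are all illustrative.

```python
class EarlyStoppingSketch:
    """Illustrative stand-in, not pybind11_ke.utils.EarlyStopping."""

    def __init__(self, patience: int = 2, delta: float = 0.0):
        self.patience = patience    # validations to wait without improvement
        self.delta = delta          # minimum gain that counts as improvement
        self.best_score = None
        self.counter = 0

    def step(self, score: float) -> bool:
        """Record one validation score; return True when training should stop."""
        if self.best_score is None or score > self.best_score + self.delta:
            # Improvement: remember the score and reset the waiting counter.
            self.best_score = score
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


stopper = EarlyStoppingSketch(patience=2, delta=0.0)
scores = [0.40, 0.45, 0.44, 0.45, 0.46]
stopped_at = None
for i, score in enumerate(scores):
    if stopper.step(score):
        stopped_at = i   # index of the validation that triggered the stop
        break
```

In this sketch the third and fourth scores do not exceed the best score of 0.45, so with patience = 2 the loop stops at the fourth validation and never sees the later 0.46. A larger patience trades longer training for robustness against such plateaus.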

print_test(sampling_mode: str, epoch: int = 0)

Run link prediction according to the type of tester.

Parameters:

sampling_mode (str) – the data

run()

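The training loop. It first decides from use_gpu whether model is trained on the GPU, then sets optimizer according to opt_method, and finally iterates over data_loader to fetch batches and trains on each with train_one_step().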
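
The role of valid_interval, log_interval, and save_interval within such a loop can be sketched without the library. plan_epoch is a hypothetical function, and the rule that an action fires when the 1-based epoch number is a multiple of its interval is an assumption; the documentation here does not state the exact rule.

```python
from typing import List, Optional


def plan_epoch(epoch: int,
               valid_interval: Optional[int] = None,
               log_interval: Optional[int] = None,
               save_interval: Optional[int] = None) -> List[str]:
    """Return the actions assumed to fire at the end of a 1-based `epoch`."""

    def due(interval: Optional[int]) -> bool:
        # None is assumed to disable the action (from the `int | None` type).
        return interval is not None and epoch % interval == 0

    actions = []
    if due(log_interval):
        actions.append("log")
    if due(valid_interval):
        actions.append("validate")
    if due(save_interval):
        actions.append("save")
    return actions
```

With the settings from the example at the top of this page (all three intervals set to 50), this sketch would log, validate, and save at epochs 50, 100, and so on up to 2000.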

save_interval: int | None

The number of training epochs between model checkpoints.

save_path: str | None

The path where the model is saved.

scheduler: torch.optim.lr_scheduler.MultiStepLR | None

The learning-rate scheduler.

test: bool

Whether to evaluate the model on the test set; requires tester to be non-None.

tester: Tester | GraphTester | None

The tester object used to evaluate the model.

to_var(x: torch.Tensor) → torch.Tensor

Move x to the corresponding device.

Parameters:

x (torch.Tensor) – the data

Returns:

the tensor

Return type:

torch.Tensor

train_one_step(data: dict[str, Union[dgl.DGLGraph, torch.Tensor]]) → float

Train the model for one step on one batch of data produced by data_loader.

Parameters:

data (dict[str, Union[dgl.DGLGraph, torch.Tensor]]) – the data that data_loader produces with pybind11_ke.data.GraphSampler.sampling()

Returns:

the loss value

Return type:

float

use_early_stopping: bool

Whether to enable early stopping; requires tester and save_path to be non-None.

use_gpu: bool

Whether to use the GPU.

use_wandb: bool

Whether to use wandb for logging.

valid_interval: int | None

The number of training epochs between evaluations on the validation set; requires tester to be non-None.
