GraphTrainer
- class pybind11_ke.config.GraphTrainer(model: RGCNSampling | CompGCNSampling | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', use_gpu: bool = True, device: str = 'cuda:0', tester: GraphTester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hit10', patience: int = 2, delta: float = 0, use_wandb: bool = False, gpu_id: int | None = None)
Mainly used to train R-GCN [SKB+18] and CompGCN [VSNT20]. Example:
    from pybind11_ke.data import CompGCNSampler, CompGCNTestSampler, GraphDataLoader
    from pybind11_ke.module.model import CompGCN
    from pybind11_ke.module.loss import Cross_Entropy_Loss
    from pybind11_ke.module.strategy import CompGCNSampling
    from pybind11_ke.config import GraphTrainer, GraphTester

    dataloader = GraphDataLoader(
        in_path = "../../benchmarks/FB15K237/",
        batch_size = 2048,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = CompGCNSampler,
        test_sampler = CompGCNTestSampler
    )

    # define the model
    compgcn = CompGCN(
        ent_tol = dataloader.train_sampler.ent_tol,
        rel_tol = dataloader.train_sampler.rel_tol,
        dim = 100
    )

    # define the loss function
    model = CompGCNSampling(
        model = compgcn,
        loss = Cross_Entropy_Loss(model = compgcn),
        ent_tol = dataloader.train_sampler.ent_tol
    )

    # test the model
    tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

    # train the model
    trainer = GraphTrainer(model = model, data_loader = dataloader.train_dataloader(),
        epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
        tester = tester, test = True, valid_interval = 50, log_interval = 50,
        save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
    )
    trainer.run()
- __init__(model: RGCNSampling | CompGCNSampling | None = None, data_loader: torch.utils.data.DataLoader | None = None, epochs: int = 1000, lr: float = 0.5, opt_method: str = 'Adam', use_gpu: bool = True, device: str = 'cuda:0', tester: GraphTester | None = None, test: bool = False, valid_interval: int | None = None, log_interval: int | None = None, save_interval: int | None = None, save_path: str | None = None, use_early_stopping: bool = True, metric: str = 'hit10', patience: int = 2, delta: float = 0, use_wandb: bool = False, gpu_id: int | None = None)
Creates a GraphTrainer object.
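- Parameters:
model (pybind11_ke.module.strategy.RGCNSampling or pybind11_ke.module.strategy.CompGCNSampling) – the training strategy class that wraps the KGE model
data_loader (torch.utils.data.DataLoader) – the DataLoader
epochs (int) – number of training epochs
lr (float) – learning rate
opt_method (str) – optimizer: 'Adam' or 'adam', 'Adagrad' or 'adagrad', 'SGD' or 'sgd'
use_gpu (bool) – whether to use a GPU
device (str) – which GPU to use
tester (pybind11_ke.config.Tester) – the validation class used to evaluate the model
log_interval (int) – number of epochs between log outputs
save_interval (int) – number of epochs between model saves
save_path (str) – path where the model is saved
use_early_stopping (bool) – whether to enable early stopping; requires tester and save_path to be set
metric (str) – validation metric used for early stopping. Options: 'mr', 'mrr', 'hit1', 'hit3', 'hit10'. Default: 'hit10'
patience (int) – the pybind11_ke.utils.EarlyStopping.patience parameter: how long to wait after the last improvement in the validation score. Default: 2
delta (float) – the pybind11_ke.utils.EarlyStopping.delta parameter: the minimum change in the monitored quantity that counts as an improvement. Default: 0
use_wandb (bool) – whether to use wandb for logging
gpu_id (int) – index of the GPU, used for parallel training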
- __weakref__
list of weak references to the object (if defined)
- configure_optimizers()
Override this method to customize how the optimizer is configured.
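The override hook can be illustrated with a stand-in class. This is only a minimal sketch, assuming the trainer keeps its learning rate on `self.lr` and stores the result on `self.optimizer`; `SketchTrainer`, `HalfLrTrainer`, and the dictionary used in place of a real `torch.optim` optimizer are hypothetical and are not part of pybind11_ke.

```python
class SketchTrainer:
    """Stand-in for a trainer class; NOT pybind11_ke.config.GraphTrainer."""

    # Lower-cased names mapped to canonical labels, so that 'Adam' and
    # 'adam' (and likewise for the other optimizers) select the same entry.
    _KNOWN = {"adam": "Adam", "adagrad": "Adagrad", "sgd": "SGD"}

    def __init__(self, lr=0.5, opt_method="Adam"):
        self.lr = lr
        self.opt_method = opt_method
        self.optimizer = None

    def configure_optimizers(self):
        """Default hook: pick an optimizer from opt_method."""
        key = self.opt_method.lower()
        if key not in self._KNOWN:
            raise ValueError(f"unsupported opt_method: {self.opt_method!r}")
        # Assumption: a real trainer would build a torch.optim optimizer
        # here; a plain dict keeps the sketch free of dependencies.
        self.optimizer = {"name": self._KNOWN[key], "lr": self.lr}


class HalfLrTrainer(SketchTrainer):
    """Hypothetical subclass that customizes the hook."""

    def configure_optimizers(self):
        # Reuse the default choice, then adjust one setting.
        super().configure_optimizers()
        self.optimizer["lr"] = self.lr * 0.5


trainer = HalfLrTrainer(lr=0.5, opt_method="sgd")
trainer.configure_optimizers()
print(trainer.optimizer)
```

Calling the parent implementation first keeps the default behaviour and changes only what the subclass needs; a subclass could equally replace the hook body entirely.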
- data_loader: Union[TrainDataLoader, torch.utils.data.DataLoader, None]
The pybind11_ke.data.TrainDataLoader or torch.utils.data.DataLoader passed to __init__().
- delta: float
The pybind11_ke.utils.EarlyStopping.delta parameter: the minimum change in the monitored quantity that counts as an improvement. Default: 0
- device: Union[torch.device, str]
The GPU: a torch.device object constructed from device.
- early_stopping: EarlyStopping
The early-stopping object.
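How a patience and a delta value typically interact in early stopping can be sketched as follows. This is an illustration under stated assumptions, not the code of `pybind11_ke.utils.EarlyStopping`: the class `EarlyStoppingSketch`, its `step` method, and the helper `stopped_at` are hypothetical, and the sketch assumes a metric where higher is better (for a rank-style metric such as 'mr' the comparison would have to be reversed).

```python
class EarlyStoppingSketch:
    """Hypothetical patience/delta logic; not the library's EarlyStopping."""

    def __init__(self, patience=2, delta=0.0):
        self.patience = patience   # validations to wait after the last improvement
        self.delta = delta         # minimum gain that counts as an improvement
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def step(self, score):
        """Record one validation score; return True once training should stop."""
        if self.best_score is None or score > self.best_score + self.delta:
            # Improvement: remember it and reset the waiting counter.
            self.best_score = score
            self.counter = 0
        else:
            # No sufficient improvement: count one more validation round.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


def stopped_at(scores, patience=2, delta=0.0):
    """Return the index of the validation round that triggers the stop, or None."""
    stopper = EarlyStoppingSketch(patience=patience, delta=delta)
    for index, score in enumerate(scores):
        if stopper.step(score):
            return index
    return None


print(stopped_at([0.30, 0.35, 0.34, 0.35]))
```

With a larger delta, small gains no longer reset the counter, so training stops sooner on a plateau.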
- metric: str
Validation metric used for early stopping. Options: 'mrr', 'hit1', 'hit3', 'hit10', 'mrTC', 'mrrTC', 'hit1TC', 'hit3TC', 'hit10TC'. The 'mrTC', 'mrrTC', 'hit1TC', 'hit3TC' and 'hit10TC' options require pybind11_ke.data.TestDataLoader.type_constrain to be True. Default: 'hit10'
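For orientation, the usual link-prediction definitions behind names like 'mr', 'mrr' and 'hit10' can be sketched as below. This assumes the standard definitions (mean rank, mean reciprocal rank, and the fraction of ranks at or below k, with 1-based ranks); the function `link_prediction_metrics` is a hypothetical helper, not a pybind11_ke API, and the type-constrained 'TC' variants are not covered.

```python
def link_prediction_metrics(ranks):
    """Compute mr, mrr and hit@k from 1-based ranks of the correct entities."""
    if not ranks:
        raise ValueError("ranks must not be empty")
    n = len(ranks)

    def hits_at(k):
        # Fraction of test triples whose correct entity is ranked in the top k.
        return sum(1 for r in ranks if r <= k) / n

    return {
        "mr": sum(ranks) / n,                     # mean rank, lower is better
        "mrr": sum(1.0 / r for r in ranks) / n,   # mean reciprocal rank, higher is better
        "hit1": hits_at(1),
        "hit3": hits_at(3),
        "hit10": hits_at(10),
    }


print(link_prediction_metrics([1, 2, 4, 20]))
```

Note that 'mr' improves as it decreases, while the other metrics improve as they increase, which matters when a metric is used to decide whether a validation result counts as an improvement.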
- model: torch.nn.parallel.DistributedDataParallel | NegativeSampling | RGCNSampling | CompGCNSampling | None
The training strategy class that wraps the KGE model, that is, pybind11_ke.module.strategy.NegativeSampling, pybind11_ke.module.strategy.RGCNSampling or pybind11_ke.module.strategy.CompGCNSampling.
- optimizer: torch.optim.SGD | torch.optim.Adagrad | torch.optim.Adam | None
The optimizer created according to the opt_method passed to __init__().
- patience: int
The pybind11_ke.utils.EarlyStopping.patience parameter: how long to wait after the last improvement in the validation score. Default: 2
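- run()
The training loop. It first decides from use_gpu whether model is trained on a GPU, then sets up optimizer according to opt_method, and finally iterates over data_loader to fetch data and trains on it with train_one_step().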
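The shape of such a loop, including the interval parameters accepted by the constructor, can be sketched in plain Python. This is a minimal sketch, not the library's implementation: `run_sketch` is a hypothetical function, and counting epochs from 1, summing the per-batch losses, and treating a `None` interval as "disabled" are all assumptions made for illustration.

```python
def run_sketch(batches, train_one_step, epochs,
               log_interval=None, valid_interval=None, save_interval=None):
    """Iterate over epochs and batches, recording when each periodic action fires."""
    events = []
    for epoch in range(1, epochs + 1):          # assumption: epochs counted from 1
        epoch_loss = 0.0
        for batch in batches:
            # One optimization step per batch; the step returns a loss value.
            epoch_loss += train_one_step(batch)
        if log_interval and epoch % log_interval == 0:
            events.append(("log", epoch, epoch_loss))
        if valid_interval and epoch % valid_interval == 0:
            # A real trainer would run its tester here and could consult
            # an early-stopping object with the resulting metric.
            events.append(("valid", epoch))
        if save_interval and epoch % save_interval == 0:
            events.append(("save", epoch))
    return events


# Toy usage: each "batch" is a number and the "loss" is that number.
for event in run_sketch([1, 2, 3], float, epochs=4, log_interval=2, valid_interval=4):
    print(event)
```

Recording events instead of printing or writing files keeps the sketch easy to inspect; a real loop would replace each event with the corresponding logging, evaluation, or checkpoint call.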
- scheduler: torch.optim.lr_scheduler.MultiStepLR | None
The learning-rate scheduler.
- tester: Tester | GraphTester | None
The validation class used to evaluate the model.
- to_var(x: torch.Tensor) → torch.Tensor
Moves x to the corresponding device.
- Parameters:
x (torch.Tensor) – the data
- Returns:
the tensor
- Return type:
torch.Tensor
- train_one_step(data: dict[str, Union[dgl.DGLGraph, torch.Tensor]]) → float
Trains the model for one step on one batch of data produced by data_loader.
- Parameters:
data (dict[str, Union[dgl.DGLGraph, torch.Tensor]]) – the data that data_loader generates with the pybind11_ke.data.GraphSampler.sampling() function
- Returns:
the loss value
- Return type:
float