GraphTester

class pybind11_ke.config.GraphTester(model: RGCN | CompGCN | None = None, data_loader: GraphDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_gpu: bool = True, device: str = 'cuda:0')

Mainly used to evaluate R-GCN [SKB+18] models. The constructor also accepts a CompGCN model, as in the example that follows.

Example:

from pybind11_ke.data import CompGCNSampler, CompGCNTestSampler, GraphDataLoader
from pybind11_ke.module.model import CompGCN
from pybind11_ke.module.loss import Cross_Entropy_Loss
from pybind11_ke.module.strategy import CompGCNSampling
from pybind11_ke.config import GraphTrainer, GraphTester

dataloader = GraphDataLoader(
        in_path = "../../benchmarks/FB15K237/",
        batch_size = 2048,
        test_batch_size = 256,
        num_workers = 16,
        train_sampler = CompGCNSampler,
        test_sampler = CompGCNTestSampler
)

# define the model
compgcn = CompGCN(
        ent_tol = dataloader.train_sampler.ent_tol,
        rel_tol = dataloader.train_sampler.rel_tol,
        dim = 100
)

# wrap the model and its loss function in the training strategy
model = CompGCNSampling(
        model = compgcn,
        loss = Cross_Entropy_Loss(model = compgcn),
        ent_tol = dataloader.train_sampler.ent_tol
)

# create the tester used for validation during training
tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

# train the model
trainer = GraphTrainer(model = model, data_loader = dataloader.train_dataloader(),
        epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
        tester = tester, test = True, valid_interval = 50, log_interval = 50,
        save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
)
trainer.run()
__init__(model: RGCN | CompGCN | None = None, data_loader: GraphDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_gpu: bool = True, device: str = 'cuda:0')

Creates a Tester object.

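Parameters:

model (RGCN | CompGCN | None) – The KGE model
data_loader (GraphDataLoader | None) – A pybind11_ke.data.GraphDataLoader
sampling_mode (str) – Negative sampling mode: link_test or link_valid
prediction (str) – Link prediction mode: 'all', 'head', or 'tail'
use_gpu (bool) – Whether to use the GPU
device (str) – Device string used to construct the torch.device object
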
__weakref__

list of weak references to the object (if defined)

data_loader: TestDataLoader | GraphDataLoader | None

pybind11_ke.data.TestDataLoader or pybind11_ke.data.GraphDataLoader

device: torch.device

The GPU device: a torch.device object constructed from the device argument.

model: Union[Model, None]

The KGE model, i.e. pybind11_ke.module.model.Model

prediction: str

Link prediction mode: 'all', 'head', or 'tail'

Runs link prediction.

Returns:

The classic metrics, in order: MR, MRR, Hits@1, Hits@3, Hits@10

Return type:

tuple[float, …]
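
For readers unfamiliar with these metrics, the following is a minimal sketch of how MR, MRR and Hits@k are conventionally computed from a list of 1-based ranks. The function name link_prediction_metrics is hypothetical and not part of pybind11_ke. Whether the library reports Hits@k as fractions or as percentages is not stated here, so the sketch assumes fractions.

```python
def link_prediction_metrics(ranks):
    """Return (MR, MRR, Hits@1, Hits@3, Hits@10) for a list of 1-based ranks.

    Hypothetical helper for illustration; not the pybind11_ke implementation.
    """
    if not ranks:
        raise ValueError("ranks must be non-empty")
    n = len(ranks)
    mr = sum(ranks) / n                        # mean rank
    mrr = sum(1.0 / r for r in ranks) / n      # mean reciprocal rank
    # Hits@k: fraction of test triples whose rank is at most k (assumed unit: fraction)
    hits = [sum(1 for r in ranks if r <= k) / n for k in (1, 3, 10)]
    return (mr, mrr, *hits)


metrics = link_prediction_metrics([1, 2, 4, 10, 50])
```

Lower MR is better, while higher MRR and Hits@k are better. MR is sensitive to a few very poor ranks, which is why it is usually reported together with the other metrics.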

sampling_mode: str

Negative sampling mode of pybind11_ke.data.TestDataLoader: link_test or link_valid

set_sampling_mode(sampling_mode: str)

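Sets sampling_mode.

Parameters:

sampling_mode (str) – Data sampling mode; link_test and link_valid select negative sampling for link prediction on the test set and the validation set, respectively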
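
One way to picture the effect of sampling_mode is as a switch between the validation and test data. The sketch below is an assumption about the general pattern, not the library's internals: the class ModeSwitch and its method active_batches are hypothetical, and plain lists stand in for the data loaders.

```python
class ModeSwitch:
    """Hypothetical stand-in showing how a mode string can select a data split."""

    VALID_MODES = ("link_test", "link_valid")

    def __init__(self, val_batches, test_batches, sampling_mode="link_test"):
        self.val_batches = val_batches
        self.test_batches = test_batches
        self.set_sampling_mode(sampling_mode)

    def set_sampling_mode(self, sampling_mode):
        # Reject unknown modes early instead of silently falling back.
        if sampling_mode not in self.VALID_MODES:
            raise ValueError(f"unknown sampling_mode: {sampling_mode!r}")
        self.sampling_mode = sampling_mode

    def active_batches(self):
        # link_valid -> validation split, link_test -> test split (assumed mapping)
        if self.sampling_mode == "link_valid":
            return self.val_batches
        return self.test_batches


switch = ModeSwitch(val_batches=["v0"], test_batches=["t0", "t1"])
test_side = switch.active_batches()
switch.set_sampling_mode("link_valid")
valid_side = switch.active_batches()
```

Validating the mode string inside the setter keeps a typo such as "link_val" from evaluating on the wrong split without notice.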

test_dataloader: torch.utils.data.DataLoader

Test data loader.

test_one_step(data: dict[str, Union[numpy.ndarray, str]]) -> numpy.ndarray

Evaluates the model for one step on one batch of data produced by data_loader.

Parameters:

data (dict[str, Union[np.ndarray, str]]) – Data that data_loader generates with pybind11_ke.data.TestDataLoader.sampling()

Returns:

Scores of the triples

Return type:

numpy.ndarray
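
How the returned scores become ranks is not shown in this reference. As a rough illustration, the sketch below ranks one target entity against all candidates, optionally skipping other entities known to be correct (the common "filtered" setting). The function rank_of_target is hypothetical, a plain list stands in for the score array, and it assumes a higher score means a more plausible triple, which may not match every model's convention.

```python
def rank_of_target(scores, target, known_true=()):
    """Return the 1-based rank of scores[target] among the candidates.

    Hypothetical helper. Assumes higher score = more plausible triple.
    Candidates listed in known_true (other correct answers) are skipped.
    Ties are resolved optimistically: equal scores do not push the rank down.
    """
    target_score = scores[target]
    excluded = set(known_true) - {target}
    better = sum(
        1
        for i, s in enumerate(scores)
        if i not in excluded and s > target_score
    )
    return 1 + better


scores = [0.1, 0.9, 0.4, 0.7]                 # one score per candidate entity
raw_rank = rank_of_target(scores, target=2)   # candidates 1 and 3 score higher
filtered_rank = rank_of_target(scores, target=2, known_true=[1])
```

Here raw_rank is 3 and filtered_rank is 2, because candidate 1 is treated as another correct answer and no longer counts against the target.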

to_var(x: torch.Tensor) -> torch.Tensor

Returns x as a tensor, placed according to use_gpu.

Parameters:

x (torch.Tensor) – The data

Returns:

The tensor

Return type:

torch.Tensor

use_gpu: bool

Whether to use the GPU

val_dataloader: torch.utils.data.DataLoader

Validation data loader.
