Shortcuts

Tester

class pybind11_ke.config.Tester(model: Model | None = None, data_loader: TestDataLoader | GraphDataLoader | None = None, sampling_mode: str = 'link_test', use_gpu: bool = True, device: str = 'cuda:0')[源代码]

主要用于 KGE 模型的评估。

例子:

from pybind11_ke.config import Trainer, Tester

# test the model
transe.load_checkpoint('../checkpoint/transe.ckpt')
tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True)
tester.run_link_prediction()
__init__(model: Model | None = None, data_loader: TestDataLoader | GraphDataLoader | None = None, sampling_mode: str = 'link_test', use_gpu: bool = True, device: str = 'cuda:0')[源代码]

创建 Tester 对象。

参数:
__weakref__

list of weak references to the object (if defined)

data_loader: TestDataLoader | GraphDataLoader | None

pybind11_ke.data.TestDataLoader or pybind11_ke.data.GraphDataLoader

device: torch.device

gpu,利用 device 构造的 torch.device 对象

model: Model | None

KGE 模型,即 pybind11_ke.module.model.Model

进行链接预测。

返回:

经典指标分别为 MR,MRR,Hits@1,Hits@3,Hits@10

返回类型:

tuple[float, …]

sampling_mode: str

pybind11_ke.data.TestDataLoader 负采样的方式:link_test or link_valid

set_sampling_mode(sampling_mode: str)[源代码]

设置 sampling_mode

参数:

sampling_mode (str) – 数据采样模式,link_testlink_valid 分别表示为链接预测进行测试集和验证集的负采样

test_one_step(data: dict[str, Union[numpy.ndarray, str]]) numpy.ndarray[源代码]

根据 data_loader 生成的 1 批次(batch) data 将模型验证 1 步。

参数:

data (dict[str, Union[np.ndarray, str]]) – data_loader 利用 pybind11_ke.data.TestDataLoader.sampling() 函数生成的数据

返回:

三元组的得分

返回类型:

numpy.ndarray

to_var(x: numpy.ndarray, use_gpu: bool) torch.Tensor[源代码]

根据 use_gpu 返回 x 的张量

参数:
返回:

张量

返回类型:

torch.Tensor

use_gpu: bool

是否使用 gpu

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs