Tester¶
- class pybind11_ke.config.Tester(model: Model | None = None, data_loader: TestDataLoader | GraphDataLoader | None = None, sampling_mode: str = 'link_test', use_gpu: bool = True, device: str = 'cuda:0')[源代码]¶
主要用于 KGE 模型的评估。
例子:
from pybind11_ke.config import Trainer, Tester # test the model transe.load_checkpoint('../checkpoint/transe.ckpt') tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True) tester.run_link_prediction()
- __init__(model: Model | None = None, data_loader: TestDataLoader | GraphDataLoader | None = None, sampling_mode: str = 'link_test', use_gpu: bool = True, device: str = 'cuda:0')[源代码]¶
创建 Tester 对象。
- 参数:
model (
pybind11_ke.module.model.Model) – KGE 模型data_loader (
pybind11_ke.data.TestDataLoaderorpybind11_ke.data.GraphDataLoader) – TestDataLoader or GraphDataLoadersampling_mode (str) –
pybind11_ke.data.TestDataLoader负采样的方式:link_testorlink_validuse_gpu (bool) – 是否使用 gpu
device (str) – 使用哪个 gpu
- __weakref__¶
list of weak references to the object (if defined)
- data_loader: TestDataLoader | GraphDataLoader | None¶
pybind11_ke.data.TestDataLoaderorpybind11_ke.data.GraphDataLoader
- device: torch.device¶
gpu,利用
device构造的torch.device对象
- model: Model | None¶
KGE 模型,即
pybind11_ke.module.model.Model
- sampling_mode: str¶
pybind11_ke.data.TestDataLoader负采样的方式:link_testorlink_valid
- set_sampling_mode(sampling_mode: str)[源代码]¶
-
- 参数:
sampling_mode (str) – 数据采样模式,
link_test和link_valid分别表示为链接预测进行测试集和验证集的负采样
- test_one_step(data: dict[str, Union[numpy.ndarray, str]]) numpy.ndarray[源代码]¶
根据
data_loader生成的 1 批次(batch)data将模型验证 1 步。- 参数:
data (dict[str, Union[np.ndarray, str]]) –
data_loader利用pybind11_ke.data.TestDataLoader.sampling()函数生成的数据- 返回:
三元组的得分
- 返回类型:
- to_var(x: numpy.ndarray, use_gpu: bool) torch.Tensor[源代码]¶
根据
use_gpu返回x的张量- 参数:
x (numpy.ndarray) – 数据
use_gpu (bool) – 是否使用 gpu
- 返回:
张量
- 返回类型: