Tester
- class pybind11_ke.config.Tester(model: Model | None = None, data_loader: KGEDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_tqdm: bool = True, use_gpu: bool = True, device: str = 'cuda:0', only_test: bool = False)
Mainly used for evaluating KGE models.
Example:
```python
from pybind11_ke.data import KGEDataLoader, BernSampler, TradTestSampler
from pybind11_ke.module.model import TransE
from pybind11_ke.module.loss import MarginLoss
from pybind11_ke.module.strategy import NegativeSampling
from pybind11_ke.config import Trainer, Tester

# dataloader for training (test = True also builds the evaluation data)
dataloader = KGEDataLoader(
    in_path = "../../benchmarks/FB15K/",
    batch_size = 8192,
    neg_ent = 25,
    test = True,
    test_batch_size = 256,
    num_workers = 16,
    train_sampler = BernSampler,
    test_sampler = TradTestSampler
)

# define the model
transe = TransE(
    ent_tol = dataloader.train_sampler.ent_tol,
    rel_tol = dataloader.train_sampler.rel_tol,
    dim = 50,
    p_norm = 1,
    norm_flag = True)

# wrap the model with the negative sampling strategy and the margin loss
model = NegativeSampling(
    model = transe,
    loss = MarginLoss(margin = 1.0),
    regul_rate = 0.01
)

# define the tester, which the Trainer uses to evaluate the model
tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

# train the model
trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
    epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
    tester = tester, test = True, valid_interval = 100,
    log_interval = 100, save_interval = 100,
    save_path = '../../checkpoint/transe.pth', delta = 0.01)
trainer.run()
```
- __init__(model: Model | None = None, data_loader: KGEDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_tqdm: bool = True, use_gpu: bool = True, device: str = 'cuda:0', only_test: bool = False)
Creates a Tester object.
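- Parameters:
model (pybind11_ke.module.model.Model) – the KGE model
data_loader (pybind11_ke.data.KGEDataLoader) – the data loader, an instance of pybind11_ke.data.KGEDataLoader
sampling_mode (str) – whether to evaluate on the test set or the validation set: 'link_test' or 'link_valid'
prediction (str) – link prediction mode: 'all', 'head', or 'tail'
use_tqdm (bool) – whether to show a progress bar
use_gpu (bool) – whether to use a GPU
device (str) – which GPU to use, e.g. 'cuda:0'
only_test (bool) – whether the Tester is only evaluating an already trained model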
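The prediction parameter selects which side of a test triple is predicted. To make the three modes concrete, here is a minimal, library-independent sketch of how head and tail prediction are commonly ranked in KGE evaluation; it is not pybind11_ke's implementation. The names rank_of_true, link_prediction_ranks and toy_score are hypothetical, and the sketch assumes that a lower score means a more plausible triple (as with distance-based models such as TransE). It ranks the true entity against every entity id, without filtering out other known true triples.

```python
from typing import Callable, List, Tuple

# (head, relation, tail) as integer ids
Triple = Tuple[int, int, int]


def rank_of_true(scores: List[float], true_idx: int) -> int:
    """1-based rank of the true candidate, assuming lower scores are more plausible."""
    true_score = scores[true_idx]
    # Only strictly better candidates push the rank down, so ties are
    # resolved in favour of the true entity.
    return 1 + sum(1 for s in scores if s < true_score)


def link_prediction_ranks(
    triple: Triple,
    ent_tol: int,
    score_fn: Callable[[int, int, int], float],
    prediction: str = "all",
) -> List[int]:
    """Rank the true head and/or tail of `triple` against every entity id."""
    if prediction not in ("all", "head", "tail"):
        raise ValueError(f"unknown prediction mode: {prediction!r}")
    h, r, t = triple
    ranks: List[int] = []
    if prediction in ("all", "head"):
        # Head prediction: replace the head with each entity, keep r and t.
        head_scores = [score_fn(e, r, t) for e in range(ent_tol)]
        ranks.append(rank_of_true(head_scores, h))
    if prediction in ("all", "tail"):
        # Tail prediction: replace the tail with each entity, keep h and r.
        tail_scores = [score_fn(h, r, e) for e in range(ent_tol)]
        ranks.append(rank_of_true(tail_scores, t))
    return ranks


def toy_score(h: int, r: int, t: int) -> float:
    # Hypothetical asymmetric score; lower means more plausible.
    return abs(2 * h + r - t)


print(link_prediction_ranks((0, 1, 3), 5, toy_score, "all"))   # [2, 4]
print(link_prediction_ranks((0, 1, 3), 5, toy_score, "head"))  # [2]
print(link_prediction_ranks((0, 1, 3), 5, toy_score, "tail"))  # [4]
```

With 'all', both the head and the tail of each test triple are ranked, so every triple contributes two ranks; 'head' and 'tail' each contribute one.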
- __weakref__
list of weak references to the object (if defined)
- data_loader: KGEDataLoader
- device: torch.device
The GPU to use: a torch.device object constructed from the device argument.
- model: Model
The KGE model, i.e. a pybind11_ke.module.model.Model.
- set_sampling_mode(sampling_mode: str)
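Sets the data sampling mode.
- Parameters:
sampling_mode (str) – data sampling mode; 'link_test' and 'link_valid' mean negative sampling on the test set and on the validation set, respectively, for link prediction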
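The following is a hypothetical sketch of how a tester could use sampling_mode to choose between its test_dataloader and val_dataloader attributes. ToyTester and its current_dataloader method are illustrative names, not part of pybind11_ke, and the real Tester may wire this up differently.

```python
from typing import Any, Iterable

# Accepted modes, as listed in the Tester documentation.
VALID_MODES = ("link_test", "link_valid")


class ToyTester:
    """Hypothetical stand-in for Tester, reduced to mode switching."""

    def __init__(self, test_dataloader: Iterable[Any], val_dataloader: Iterable[Any],
                 sampling_mode: str = "link_test") -> None:
        self.test_dataloader = test_dataloader
        self.val_dataloader = val_dataloader
        self.set_sampling_mode(sampling_mode)

    def set_sampling_mode(self, sampling_mode: str) -> None:
        # Reject anything other than the two documented modes; the old
        # mode is kept if validation fails.
        if sampling_mode not in VALID_MODES:
            raise ValueError(f"sampling_mode must be one of {VALID_MODES}, got {sampling_mode!r}")
        self.sampling_mode = sampling_mode

    def current_dataloader(self) -> Iterable[Any]:
        # 'link_test' evaluates the test split, 'link_valid' the validation split.
        if self.sampling_mode == "link_test":
            return self.test_dataloader
        return self.val_dataloader


tester = ToyTester(test_dataloader=["test batch"], val_dataloader=["valid batch"])
print(tester.current_dataloader())  # ['test batch']
tester.set_sampling_mode("link_valid")
print(tester.current_dataloader())  # ['valid batch']
```

Keeping both loaders on the object and switching only a mode string lets the same tester alternate between validation during training and a final test run.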
- test_dataloader: torch.utils.data.DataLoader
Test data loader.
- to_var(x: torch.Tensor) → torch.Tensor
Returns x as a tensor, depending on use_gpu.
- Parameters:
x (torch.Tensor) – the data
- Returns:
the tensor
- Return type:
torch.Tensor
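A minimal sketch of the behaviour described above, assuming that to_var moves x to the configured device when use_gpu is true and returns it unchanged otherwise. This free function and its parameters are hypothetical, not the library code; it relies only on torch.Tensor.to(device), and is written duck-typed so that any object with a .to(device) method works.

```python
from typing import Any


def to_var(x: Any, use_gpu: bool, device: Any) -> Any:
    """Hypothetical stand-in for Tester.to_var (assumed behaviour).

    x is expected to behave like a torch.Tensor, i.e. to provide .to(device).
    """
    if use_gpu:
        # torch.Tensor.to(device) returns a tensor on `device`
        # (x itself if it is already there).
        return x.to(device)
    # CPU path: hand the tensor back unchanged.
    return x
```

With PyTorch installed, to_var(torch.zeros(3), torch.cuda.is_available(), torch.device("cuda:0")) would return the tensor on the first GPU when CUDA is available and the original CPU tensor otherwise.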
- val_dataloader: torch.utils.data.DataLoader
Validation data loader.