GraphTester
- class pybind11_ke.config.GraphTester(model: RGCN | CompGCN | None = None, data_loader: GraphDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_gpu: bool = True, device: str = 'cuda:0')
Mainly used to evaluate R-GCN [SKB+18] models. The model parameter also accepts a CompGCN model, as in this example:
```python
from pybind11_ke.data import CompGCNSampler, CompGCNTestSampler, GraphDataLoader
from pybind11_ke.module.model import CompGCN
from pybind11_ke.module.loss import Cross_Entropy_Loss
from pybind11_ke.module.strategy import CompGCNSampling
from pybind11_ke.config import GraphTrainer, GraphTester

# load FB15K237 with the CompGCN training and test samplers
dataloader = GraphDataLoader(
    in_path = "../../benchmarks/FB15K237/",
    batch_size = 2048,
    test_batch_size = 256,
    num_workers = 16,
    train_sampler = CompGCNSampler,
    test_sampler = CompGCNTestSampler
)

# define the model
compgcn = CompGCN(
    ent_tol = dataloader.train_sampler.ent_tol,
    rel_tol = dataloader.train_sampler.rel_tol,
    dim = 100
)

# wrap the model and the loss function in the training strategy
model = CompGCNSampling(
    model = compgcn,
    loss = Cross_Entropy_Loss(model = compgcn),
    ent_tol = dataloader.train_sampler.ent_tol
)

# create the tester (predicts tail entities only)
tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

# train the model, evaluating with the tester every valid_interval epochs
trainer = GraphTrainer(model = model, data_loader = dataloader.train_dataloader(),
    epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
    tester = tester, test = True, valid_interval = 50, log_interval = 50,
    save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
)
trainer.run()
```
- __init__(model: RGCN | CompGCN | None = None, data_loader: GraphDataLoader | None = None, sampling_mode: str = 'link_test', prediction: str = 'all', use_gpu: bool = True, device: str = 'cuda:0')
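Create a Tester object.
- Parameters:
model (pybind11_ke.module.model.RGCN or pybind11_ke.module.model.CompGCN) – the model to evaluate: RGCN or CompGCN
data_loader (pybind11_ke.data.GraphDataLoader) – the GraphDataLoader that provides the data
sampling_mode (str) – whether to evaluate on the test set ('link_test') or the validation set ('link_valid')
prediction (str) – link prediction mode: 'all', 'head' or 'tail'
use_gpu (bool) – whether to use the GPU
device (str) – which GPU to use, e.g. 'cuda:0'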
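The two string options above accept only a fixed set of values. The following is a minimal sketch of validating them; check_tester_options and PREDICTION_SIDES are hypothetical names, not part of pybind11_ke, and the reading of 'all' as "predict both head and tail" is an assumption this page does not confirm.

```python
# Hypothetical helper, not part of pybind11_ke: validates GraphTester-style options.
VALID_SAMPLING_MODES = ("link_test", "link_valid")

# Assumption: 'all' evaluates both head and tail prediction.
PREDICTION_SIDES = {
    "all": ("head", "tail"),
    "head": ("head",),
    "tail": ("tail",),
}


def check_tester_options(sampling_mode: str = "link_test", prediction: str = "all") -> tuple:
    """Validate the options and return which sides of a triple get predicted."""
    if sampling_mode not in VALID_SAMPLING_MODES:
        raise ValueError(f"sampling_mode must be one of {VALID_SAMPLING_MODES}, got {sampling_mode!r}")
    if prediction not in PREDICTION_SIDES:
        raise ValueError(f"prediction must be one of {tuple(PREDICTION_SIDES)}, got {prediction!r}")
    return PREDICTION_SIDES[prediction]


print(check_tester_options("link_valid", "tail"))  # ('tail',)
```

Failing early on an unknown string avoids silently evaluating the wrong split or the wrong prediction side.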
- __weakref__
list of weak references to the object (if defined)
- data_loader: TestDataLoader | GraphDataLoader | None
pybind11_ke.data.TestDataLoader or pybind11_ke.data.GraphDataLoader
- device: torch.device
The GPU to use: a torch.device object constructed from device
- model: Union[Model, None]
The KGE model, i.e. pybind11_ke.module.model.Model
- sampling_mode: str
The negative sampling mode of pybind11_ke.data.TestDataLoader: link_test or link_valid
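- set_sampling_mode(sampling_mode: str)
Set the data sampling mode.
- Parameters:
sampling_mode (str) – data sampling mode; link_test and link_valid perform negative sampling for link prediction on the test set and the validation set, respectively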
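The sampling mode decides which split gets evaluated. The sketch below is a hypothetical stand-in (SplitSelector is not a pybind11_ke class): it mirrors the documented attributes test_dataloader, val_dataloader and sampling_mode, and assumes that link_test reads from the test loader and link_valid from the validation loader, as the parameter description above states. Plain lists stand in for torch.utils.data.DataLoader objects.

```python
class SplitSelector:
    """Hypothetical sketch: picks the evaluation split from sampling_mode."""

    def __init__(self, test_dataloader, val_dataloader, sampling_mode: str = "link_test"):
        self.test_dataloader = test_dataloader
        self.val_dataloader = val_dataloader
        self.set_sampling_mode(sampling_mode)

    def set_sampling_mode(self, sampling_mode: str) -> None:
        # Only the two documented modes are accepted.
        if sampling_mode not in ("link_test", "link_valid"):
            raise ValueError(f"unknown sampling_mode: {sampling_mode!r}")
        self.sampling_mode = sampling_mode

    def current_loader(self):
        # link_test -> test split, link_valid -> validation split
        if self.sampling_mode == "link_test":
            return self.test_dataloader
        return self.val_dataloader


selector = SplitSelector(test_dataloader=["test batch"], val_dataloader=["valid batch"])
selector.set_sampling_mode("link_valid")
print(selector.current_loader())  # ['valid batch']
```

Switching the mode on one object, rather than building a second tester, lets a trainer validate during training and run the final test with the same model reference.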
- test_dataloader: torch.utils.data.DataLoader
Test data loader.
- test_one_step(data: dict[str, Union[numpy.ndarray, str]]) → numpy.ndarray
Validate the model for one step on one batch of data generated by data_loader.
- Parameters:
data (dict[str, Union[np.ndarray, str]]) – data generated by data_loader via pybind11_ke.data.TestDataLoader.sampling()
- Returns:
scores of the triples
- Return type:
numpy.ndarray
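In link prediction, per-triple scores like these are usually turned into the rank of the true entity among all candidate entities. This page does not document how GraphTester does that, so what follows is a generic sketch: rank_of_true_entity is a hypothetical helper, plain lists stand in for numpy.ndarray, and whether a higher or a lower score means "more plausible" depends on the model, hence the higher_is_better flag.

```python
def rank_of_true_entity(scores, true_index, higher_is_better=True, filtered=()):
    """Hypothetical sketch: 1-based rank of the true entity among all candidates.

    scores: one score per candidate entity.
    filtered: indices of other known-true entities to skip ("filtered" setting).
    """
    true_score = scores[true_index]
    skip = set(filtered)
    better = 0
    for i, s in enumerate(scores):
        if i == true_index or i in skip:
            continue
        # Count only candidates that score strictly better than the true entity.
        if (s > true_score) if higher_is_better else (s < true_score):
            better += 1
    return better + 1


scores = [0.1, 0.9, 0.4, 0.7]
print(rank_of_true_entity(scores, true_index=2))                          # 3
print(rank_of_true_entity(scores, true_index=2, filtered=[1]))            # 2
print(rank_of_true_entity(scores, true_index=2, higher_is_better=False))  # 2
```

Ties are resolved in the true entity's favour here, since only strictly better candidates push its rank down; averaging 1 / rank over all test triples gives the mean reciprocal rank.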
- to_var(x: torch.Tensor) → torch.Tensor
Return x as a tensor, placed according to use_gpu.
- Parameters:
x (torch.Tensor) – data
- Returns:
tensor
- Return type:
torch.Tensor
- val_dataloader: torch.utils.data.DataLoader
Validation data loader.