Shortcuts

TestDataLoader

class pybind11_ke.data.TestDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', sampling_mode: str = 'link_test', type_constrain: bool = True)[源代码]

主要从底层 C++ 模块获得数据用于 KGE 模型的评估。

例子:

from pybind11_ke.config import Tester
from pybind11_ke.data import TestDataLoader

# dataloader for test
test_dataloader = TestDataLoader('../../benchmarks/FB15K/')

# test the model
tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True, device = 'cuda:1')
__init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', sampling_mode: str = 'link_test', type_constrain: bool = True)[源代码]

创建 TestDataLoader 对象。

参数:
  • in_path (str) – 数据集目录

  • ent_file (str) – entity2id.txt

  • rel_file (str) – relation2id.txt

  • train_file (str) – train2id.txt

  • valid_file (str) – valid2id.txt

  • test_file (str) – test2id.txt

  • sampling_mode (str) – 数据采样模式,link_testlink_valid 分别表示为链接预测进行测试集和验证集的负采样

  • type_constrain (bool) – 是否用 type_constrain.txt 进行负采样

__iter__() TestDataSampler[源代码]

迭代器函数 iterator.__iter__(), 根据 sampling_mode 决定是评估验证集还是测试集。

__len__() int[源代码]

len() 要求 object.__len__()

返回:

test_tolvalid_tol

返回类型:

int

__weakref__

list of weak references to the object (if defined)

ent_file: str

entity2id.txt

ent_tol: int

实体的个数

get_ent_tol() int[源代码]

返回 ent_tol

返回:

ent_tol

返回类型:

int

get_rel_tol() int[源代码]

返回 rel_tol

返回:

rel_tol

返回类型:

int

get_test_tol() int[源代码]

返回 test_tol

返回:

test_tol

返回类型:

int

get_valid_tol() int[源代码]

返回 test_tol

返回:

test_tol

返回类型:

int

in_path: str

数据集目录

read()[源代码]

利用 pybind11 让底层 C++ 模块读取数据集中的数据

rel_file: str

relation2id.txt

rel_tol: int

关系的个数

sampling() dict[str, Union[numpy.ndarray, str]][源代码]

为链接预测进行采样数据,为给定的正三元组,用所有实体依次替换头尾实体得到 2 * ent_tol 个三元组。

返回:

对于一个正三元组生成的所有可能破化的三元组

返回类型:

dict[str, Union[np.ndarray, str]]

sampling_mode: str

数据采样模式,link_testlink_valid 分别表示为链接预测进行测试集和验证集的负采样

set_sampling_mode(sampling_mode: str)[源代码]

设置 sampling_mode

参数:

sampling_mode (str) – 数据采样模式,link_testlink_valid 分别表示为链接预测进行测试集和验证集的负采样

test_file: str

test2id.txt

test_tol: int

测试集三元组的个数

train_file: str

train2id.txt

type_constrain: bool

是否用 type_constrain.txt 进行负采样

valid_file: str

valid2id.txt

valid_tol: int

验证集三元组的个数

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs