Shortcuts

GraphTestSampler

class pybind11_ke.data.GraphTestSampler(sampler: GraphSampler | CompGCNSampler, valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt')[源代码]

R-GCN [SKB+18] 的测试数据采样器。

例子:

from pybind11_ke.data import GraphTestSampler, CompGCNTestSampler
from torch.utils.data import DataLoader

#: 测试数据采样器
test_sampler: typing.Union[typing.Type[GraphTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
    sampler=train_sampler,
    valid_file=valid_file,
    test_file=test_file,
)

#: 验证集三元组
data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
#: 测试集三元组
data_test: list[tuple[int, int, int]] = test_sampler.get_test()

val_dataloader = DataLoader(
    data_val,
    shuffle=False,
    batch_size=test_batch_size,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=test_sampler.sampling,
)

test_dataloader = DataLoader(
    data_test,
    shuffle=False,
    batch_size=test_batch_size,
    num_workers=num_workers,
    pin_memory=True,
    collate_fn=test_sampler.sampling,
)
__init__(sampler: GraphSampler | CompGCNSampler, valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt')[源代码]

创建 GraphTestSampler 对象。

参数:
__weakref__

list of weak references to the object (if defined)

add_valid_test_reverse_triples()[源代码]

对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。

all_true_triples: set[tuple[int, int, int]]

知识图谱所有三元组

ent_tol: int

实体的个数

get_all_true_triples() set[tuple[int, int, int]][源代码]

返回知识图谱所有三元组。

返回:

all_true_triples

返回类型:

set[tuple[int, int, int]]

get_hr2t_rt2h_from_all()[源代码]

获得 hr2t_allrt2h_all

get_test() list[tuple[int, int, int]][源代码]

返回测试集三元组。

返回:

test_triples

返回类型:

list[tuple[int, int, int]]

get_valid() list[tuple[int, int, int]][源代码]

返回验证集三元组。

返回:

valid_triples

返回类型:

list[tuple[int, int, int]]

get_valid_test_triples_id()[源代码]

读取 valid_file 文件和 test_file 文件。

hr2t_all: defaultdict[set]

知识图谱中所有 h-r 对对应的 t 集合

power: float

rt2h_all: defaultdict[set]

知识图谱中所有 r-t 对对应的 h 集合

sampler: GraphSampler | CompGCNSampler

训练数据采样器

sampling(data: list[tuple[int, int, int]]) dict[str, Union[dgl.DGLGraph, torch.Tensor]][源代码]

R-GCN [SKB+18] 的测试数据采样函数。

参数:

data (list[tuple[int, int, int]]) – 测试的正确三元组

返回:

R-GCN [SKB+18] 的测试数据

返回类型:

dict[str, Union[dgl.DGLGraph , torch.Tensor]]

test_file: str

test2id.txt

test_tol: int

测试集三元组的个数

test_triples: list[tuple[int, int, int]]

测试集三元组

triples: list[tuple[int, int, int]]

训练集三元组

valid_file: str

valid2id.txt

valid_tol: int

验证集三元组的个数

valid_triples: list[tuple[int, int, int]]

验证集三元组

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs