RGCNSampler¶
- class pybind11_ke.data.RGCNSampler(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
R-GCN[SKB+18] 的训练数据采样器。例子:
from pybind11_ke.data import RGCNSampler, CompGCNSampler from torch.utils.data import DataLoader #: 训练数据采样器 train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler( in_path=in_path, ent_file=ent_file, rel_file=rel_file, train_file=train_file, batch_size=batch_size, neg_ent=neg_ent ) #: 训练集三元组 data_train: list[tuple[int, int, int]] = train_sampler.get_train() train_dataloader = DataLoader( data_train, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True, drop_last=True, collate_fn=train_sampler.sampling, )
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
创建 RGCNSampler 对象。
- __weakref__¶
list of weak references to the object (if defined)
- add_reverse_relation()¶
增加相反关系:r` = r + rel_tol
- add_train_reverse_triples()¶
对于每一个三元组 (h, r, t),生成相反关系三元组 (t, r`, h): r` = r + rel_tol。
- build_graph(num_ent: int, triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor], power: int = -1) tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor][源代码]¶
建立子图。
- 参数:
num_ent (int) – 子图的节点数
triples (tuple[torch.Tensor, torch.Tensor, torch.Tensor]) – 知识图谱中的正确三元组子集
power (int) – 幂
- 返回:
子图、关系、边的归一化系数
- 返回类型:
- comp_deg_norm(graph: dgl.DGLGraph, power: int = -1) torch.Tensor[源代码]¶
根据目标节点度计算目标节点的归一化系数。
- 参数:
graph (dgl.DGLGraph) – 子图
power (int) – 幂
- 返回:
节点的归一化系数
- 返回类型:
- get_hr2t_rt2h_from_train()¶
获得
hr2t_train和rt2h_train。
- get_train_triples_id()¶
读取
train_file文件。
- hr2t_train: collections.defaultdict[set]¶
训练集中所有 h-r 对对应的 t 集合
- node_norm_to_edge_norm(graph: dgl.DGLGraph, node_norm: torch.Tensor) torch.Tensor[源代码]¶
根据目标节点度计算每条边的归一化系数。
- 参数:
graph (dgl.DGLGraph) – 子图
node_norm (torch.Tensor) – 节点的归一化系数
- 返回:
边的归一化系数
- 返回类型:
- rt2h_train: collections.defaultdict[set]¶
训练集中所有 r-t 对对应的 h 集合
- sampling(pos_triples: list[tuple[int, int, int]]) dict[str, Union[dgl.DGLGraph, torch.Tensor]][源代码]¶
R-GCN[SKB+18] 的采样函数。
- sampling_positive(positive_triples: list[tuple[int, int, int]]) tuple[numpy.ndarray, torch.Tensor][源代码]¶
为创建子图重新采样三元组子集,重排实体 ID。
- 参数:
- 返回:
三元组子集和原始的实体 ID
- 返回类型: