GraphDataLoader
- class pybind11_ke.data.GraphDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, num_workers: int | None = None, train_sampler: ~typing.Type[~pybind11_ke.data.GraphSampler.GraphSampler] | ~typing.Type[~pybind11_ke.data.CompGCNSampler.CompGCNSampler] = <class 'pybind11_ke.data.GraphSampler.GraphSampler'>, test_sampler: ~typing.Type[~pybind11_ke.data.GraphTestSampler.GraphTestSampler] | ~typing.Type[~pybind11_ke.data.CompGCNTestSampler.CompGCNTestSampler] = <class 'pybind11_ke.data.GraphTestSampler.GraphTestSampler'>)
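Basic data loader for graph neural network (GNN) models. It wraps a training sampler and a test sampler and provides the training, validation, and test data loaders.
Example: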
```python
from pybind11_ke.data import GraphDataLoader

# Uses the default GraphSampler / GraphTestSampler
dataloader = GraphDataLoader(
    in_path="../../benchmarks/FB15K237/",
    batch_size=60000,
    neg_ent=10,          # negatives per positive triple
    test=True,           # also load the validation and test sets
    test_batch_size=100,
    num_workers=16,
)
```
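The loader reads its inputs from the files under in_path. The sketch below assumes pybind11_ke keeps the OpenKE benchmark layout: the first line of each file holds the record count, entity2id.txt and relation2id.txt lines are `name id`, and train2id.txt lines are `head tail relation`. That layout, and the helper names read_id_map and read_triples, are assumptions for illustration; they are not part of the pybind11_ke API.

```python
import io
from typing import Dict, List, TextIO, Tuple

# Hypothetical helpers; they assume the OpenKE-style file layout described above.

def read_id_map(f: TextIO) -> Dict[str, int]:
    """Parse an entity2id.txt / relation2id.txt style file into {name: id}."""
    count = int(f.readline())              # first line: number of records
    mapping: Dict[str, int] = {}
    for _ in range(count):
        name, idx = f.readline().split()   # tab- or space-separated "name id"
        mapping[name] = int(idx)
    return mapping

def read_triples(f: TextIO) -> List[Tuple[int, int, int]]:
    """Parse a train2id.txt style file into (head, relation, tail) tuples."""
    count = int(f.readline())
    triples = []
    for _ in range(count):
        h, t, r = map(int, f.readline().split())  # assumed column order: head tail relation
        triples.append((h, r, t))
    return triples

# Tiny in-memory stand-ins for files under in_path.
entity_file = io.StringIO("3\n/m/a\t0\n/m/b\t1\n/m/c\t2\n")
train_file = io.StringIO("2\n0 1 0\n1 2 1\n")

entities = read_id_map(entity_file)
triples = read_triples(train_file)
print(len(entities), triples)  # 3 [(0, 0, 1), (1, 1, 2)]
```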
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, num_workers: int | None = None, train_sampler: ~typing.Type[~pybind11_ke.data.GraphSampler.GraphSampler] | ~typing.Type[~pybind11_ke.data.CompGCNSampler.CompGCNSampler] = <class 'pybind11_ke.data.GraphSampler.GraphSampler'>, test_sampler: ~typing.Type[~pybind11_ke.data.GraphTestSampler.GraphTestSampler] | ~typing.Type[~pybind11_ke.data.CompGCNTestSampler.CompGCNTestSampler] = <class 'pybind11_ke.data.GraphTestSampler.GraphTestSampler'>)
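Create a GraphDataLoader object.
- Parameters:
in_path (str) – Dataset directory
ent_file (str) – Name of the entity-to-ID file in in_path (default entity2id.txt)
rel_file (str) – Name of the relation-to-ID file in in_path (default relation2id.txt)
train_file (str) – Name of the training-triple file in in_path (default train2id.txt)
valid_file (str) – Name of the validation-triple file in in_path (default valid2id.txt)
test_file (str) – Name of the test-triple file in in_path (default test2id.txt)
batch_size (int | None) – Training batch size
neg_ent (int) – Number of negative triples built for each positive triple by replacing entities (head and tail); has no effect for CompGCN.
test (bool) – Whether to load the validation and test sets
test_batch_size (int | None) – Batch size for validation and testing
num_workers (int | None) – Number of worker processes used to load data
train_sampler (Union[Type[GraphSampler], Type[CompGCNSampler]]) – Sampler class for training data
test_sampler (Union[Type[GraphTestSampler], Type[CompGCNTestSampler]]) – Sampler class for test data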
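The neg_ent parameter controls how many negatives are built per positive triple by replacing entities. The sketch below assumes the common R-GCN recipe, in which each negative corrupts either the head or the tail (chosen at random) with a uniformly drawn entity. Whether GraphSampler corrupts in exactly this way, or filters out corruptions that happen to be true triples, is not stated here, and corrupt() is a hypothetical helper.

```python
import random
from typing import List, Tuple

Triple = Tuple[int, int, int]  # (head, relation, tail)

def corrupt(triples: List[Triple], num_entities: int, neg_ent: int,
            rng: random.Random) -> Tuple[List[Triple], List[int]]:
    """Return the positives followed by neg_ent corruptions per positive, plus 1/0 labels.

    Hypothetical sketch: each negative replaces either the head or the tail
    (probability 0.5 each) with a uniformly random entity. Accidental true
    triples are not filtered out.
    """
    samples = list(triples)
    labels = [1] * len(triples)
    for _ in range(neg_ent):
        for h, r, t in triples:
            e = rng.randrange(num_entities)
            if rng.random() < 0.5:
                samples.append((e, r, t))   # corrupt the head
            else:
                samples.append((h, r, e))   # corrupt the tail
            labels.append(0)
    return samples, labels

rng = random.Random(0)
positives = [(0, 0, 1), (1, 1, 2)]
samples, labels = corrupt(positives, num_entities=5, neg_ent=10, rng=rng)
print(len(samples), sum(labels))  # 22 2
```

Under this scheme a batch of n positive triples becomes n × (1 + neg_ent) scored triples, so neg_ent=10 makes each batch roughly eleven times larger.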
- test_dataloader() → torch.utils.data.DataLoader
Return the test data loader.
- Returns:
Test data loader
- Return type:
torch.utils.data.DataLoader
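test_batch_size sets how many evaluation items go into each batch of the returned loader. Assuming the loader keeps the last, smaller batch (the default for torch.utils.data.DataLoader, whose drop_last defaults to False), the batching behaves like the stdlib sketch below. batches() is a hypothetical helper, not the library's code.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

def batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of batch_size items; the last chunk may be shorter."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])

# 250 dummy (head, relation, tail) test triples, test_batch_size = 100
test_triples = [(i, 0, i + 1) for i in range(250)]
sizes = [len(b) for b in batches(test_triples, 100)]
print(sizes)  # [100, 100, 50]
```

The number of batches is therefore ceil(N / test_batch_size) for N evaluation items.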
- test_sampler: Type[GraphTestSampler] | Type[CompGCNTestSampler]
Sampler class for test data
- train_dataloader() → torch.utils.data.DataLoader
Return the training data loader.
- Returns:
Training data loader
- Return type:
torch.utils.data.DataLoader
- train_sampler: Type[GraphSampler] | Type[CompGCNSampler]
Sampler class for training data
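train_sampler and test_sampler are typed as sampler classes (Type[...]), not instances, so the loader can construct the sampler itself with its own settings. The sketch below shows that pattern in isolation. ToyLoader, ToySampler, ToyCompGCNSampler, and their constructor arguments are hypothetical and do not mirror pybind11_ke's actual constructors.

```python
from typing import List, Tuple, Type

Triple = Tuple[int, int, int]  # (head, relation, tail)

class ToySampler:
    """Hypothetical base sampler: receives the loader's settings at construction time."""
    def __init__(self, triples: List[Triple], neg_ent: int) -> None:
        self.triples = triples
        self.neg_ent = neg_ent

    def describe(self) -> str:
        return f"{type(self).__name__}(n={len(self.triples)}, neg_ent={self.neg_ent})"

class ToyCompGCNSampler(ToySampler):
    """Hypothetical variant that ignores neg_ent, like the CompGCN note on neg_ent."""
    def __init__(self, triples: List[Triple], neg_ent: int) -> None:
        super().__init__(triples, neg_ent=0)

class ToyLoader:
    def __init__(self, triples: List[Triple], neg_ent: int = 1,
                 train_sampler: Type[ToySampler] = ToySampler) -> None:
        # Store the class as given, then instantiate it with this loader's arguments.
        self.train_sampler = train_sampler
        self.sampler = train_sampler(triples, neg_ent)

data = [(0, 0, 1), (1, 1, 2)]
print(ToyLoader(data, neg_ent=10).sampler.describe())
# ToySampler(n=2, neg_ent=10)
print(ToyLoader(data, neg_ent=10, train_sampler=ToyCompGCNSampler).sampler.describe())
# ToyCompGCNSampler(n=2, neg_ent=0)
```

With the real classes, the same choice is made by passing train_sampler=CompGCNSampler and test_sampler=CompGCNTestSampler to GraphDataLoader.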
- val_dataloader() → torch.utils.data.DataLoader
Return the validation data loader.
- Returns:
Validation data loader
- Return type:
torch.utils.data.DataLoader