KGEDataLoader
- class pybind11_ke.data.KGEDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, type_constrain: bool = True, num_workers: int | None = None, train_sampler: Type[UniSampler] | Type[BernSampler] | Type[RGCNSampler] | Type[CompGCNSampler] = BernSampler, test_sampler: Type[TestSampler] = TradTestSampler)
Data loader for KGE (knowledge graph embedding) models.
Example:
    from pybind11_ke.data import KGEDataLoader, BernSampler, TradTestSampler

    dataloader = KGEDataLoader(
        in_path="../../benchmarks/FB15K/",
        batch_size=8192,
        neg_ent=25,
        test=True,
        test_batch_size=256,
        num_workers=16,
        train_sampler=BernSampler,
        test_sampler=TradTestSampler,
    )
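Before constructing the loader, it can help to confirm that the directory passed as in_path holds the files the constructor expects. The sketch below is not part of pybind11_ke: `missing_dataset_files` and `DEFAULT_FILES` are hypothetical names, the file names are the defaults from the signature, and it assumes the validation and test files are only needed when test=True.

```python
from pathlib import Path

# Default file names copied from the KGEDataLoader signature.
DEFAULT_FILES = {
    "ent_file": "entity2id.txt",
    "rel_file": "relation2id.txt",
    "train_file": "train2id.txt",
    "valid_file": "valid2id.txt",
    "test_file": "test2id.txt",
}


def missing_dataset_files(in_path="./", test=False, **overrides):
    """Hypothetical helper: list expected files that are absent from in_path."""
    names = {**DEFAULT_FILES, **overrides}
    # Assumption: validation and test files are read only when test=True.
    required = ["ent_file", "rel_file", "train_file"]
    if test:
        required += ["valid_file", "test_file"]
    root = Path(in_path)
    return [names[key] for key in required if not (root / names[key]).is_file()]


if __name__ == "__main__":
    # An empty list means every expected file was found.
    print(missing_dataset_files("../../benchmarks/FB15K/", test=True))
```

Keyword overrides use the same names as the constructor arguments, for example `missing_dataset_files(path, train_file="my_train.txt")`.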
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', valid_file: str = 'valid2id.txt', test_file: str = 'test2id.txt', batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, type_constrain: bool = True, num_workers: int | None = None, train_sampler: Type[UniSampler] | Type[BernSampler] | Type[RGCNSampler] | Type[CompGCNSampler] = BernSampler, test_sampler: Type[TestSampler] = TradTestSampler)
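Create a KGEDataLoader object.
- Parameters:
in_path (str) – dataset directory
ent_file (str) – entity file name; defaults to entity2id.txt
rel_file (str) – relation file name; defaults to relation2id.txt
train_file (str) – training set file name; defaults to train2id.txt
valid_file (str) – validation set file name; defaults to valid2id.txt
test_file (str) – test set file name; defaults to test2id.txt
batch_size (int | None) – training batch size
neg_ent (int) – number of negative triples built for each positive triple by replacing an entity; has no effect for CompGCN
test (bool) – whether to load the validation and test sets
test_batch_size (int | None) – test batch size
type_constrain (bool) – whether to report test results restricted by type_constrain.txt
num_workers (int | None) – number of processes used to load data
train_sampler (Union[Type[UniSampler], Type[BernSampler], Type[RGCNSampler], Type[CompGCNSampler]]) – training data sampler
test_sampler (Type[TestSampler]) – test data sampler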
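To make the meaning of neg_ent concrete, here is a minimal sketch of entity corruption with uniform sampling. It is an illustration only, not the library's sampler code: `corrupt_triple` is a hypothetical helper, the (head, relation, tail) tuple layout is an assumption, and the sketch does not check whether a corrupted triple happens to be a true triple elsewhere in the dataset.

```python
import random


def corrupt_triple(triple, num_entities, neg_ent=1, rng=None):
    """Hypothetical helper: build neg_ent negatives by replacing head or tail."""
    if num_entities < 2:
        raise ValueError("need at least two entities to corrupt a triple")
    rng = rng or random.Random()
    head, rel, tail = triple  # assumed layout: (head, relation, tail)
    negatives = []
    while len(negatives) < neg_ent:
        new_entity = rng.randrange(num_entities)
        # Replace the head or the tail with equal probability (uniform strategy).
        if rng.random() < 0.5:
            candidate = (new_entity, rel, tail)
        else:
            candidate = (head, rel, new_entity)
        # Discard draws that reproduce the positive triple itself.
        if candidate != triple:
            negatives.append(candidate)
    return negatives


if __name__ == "__main__":
    # With neg_ent=25, each positive triple yields 25 corrupted triples.
    print(len(corrupt_triple((0, 1, 2), num_entities=100, neg_ent=25)))
```

With batch_size = 8192 and neg_ent = 25 as in the example, a scheme like this would produce 8192 × 25 negative triples per training batch.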
- __weakref__
list of weak references to the object (if defined)
- test_dataloader() -> torch.utils.data.DataLoader
Return the test data loader.
- Returns:
the test data loader
- Return type:
torch.utils.data.DataLoader
- test_sampler: TestSampler
Test data sampler.
- train_dataloader() -> torch.utils.data.DataLoader
Return the training data loader.
- Returns:
the training data loader
- Return type:
torch.utils.data.DataLoader
- train_sampler: UniSampler | BernSampler | RGCNSampler | CompGCNSampler
Training data sampler.
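This page does not describe how the default BernSampler chooses which entity to replace. Assuming it follows the Bernoulli strategy commonly used in knowledge graph embedding (replace the head with probability tph / (tph + hpt), where tph is the average number of tails per head and hpt the average number of heads per tail for a relation), the per-relation probabilities could be computed as sketched below. `head_replace_probability` is a hypothetical helper, not a pybind11_ke function, and the (head, relation, tail) layout is an assumption.

```python
from collections import defaultdict


def head_replace_probability(triples):
    """Hypothetical helper: per relation, return tph / (tph + hpt)."""
    pairs = defaultdict(set)  # relation -> distinct (head, tail) pairs
    heads = defaultdict(set)  # relation -> distinct heads
    tails = defaultdict(set)  # relation -> distinct tails
    for head, rel, tail in triples:  # assumed layout: (head, relation, tail)
        pairs[rel].add((head, tail))
        heads[rel].add(head)
        tails[rel].add(tail)
    probabilities = {}
    for rel in pairs:
        tph = len(pairs[rel]) / len(heads[rel])  # average tails per head
        hpt = len(pairs[rel]) / len(tails[rel])  # average heads per tail
        probabilities[rel] = tph / (tph + hpt)
    return probabilities


if __name__ == "__main__":
    toy = [(0, 0, 1), (0, 0, 2), (0, 0, 3), (4, 1, 5), (6, 1, 7)]
    print(head_replace_probability(toy))
```

Under this scheme a one-to-many relation replaces the head more often, which lowers the chance of producing a negative triple that is actually true.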
- val_dataloader() -> torch.utils.data.DataLoader
Return the validation data loader.
- Returns:
the validation data loader
- Return type:
torch.utils.data.DataLoader