# coding:utf-8
#
# pybind11_ke/data/KGEDataLoader.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
#
# 为 KGE 模型读取数据.
"""
KGEDataLoader - KGE 模型读取数据集类。
"""
import typing
from .UniSampler import UniSampler
from .BernSampler import BernSampler
from .RGCNSampler import RGCNSampler
from .CompGCNSampler import CompGCNSampler
from .TestSampler import TestSampler
from .TradTestSampler import TradTestSampler
from torch.utils.data import DataLoader
class KGEDataLoader:

    """Data loader for KGE models.

    Wraps a training sampler (and optionally a test sampler) and exposes
    ``torch.utils.data.DataLoader`` objects for the train / validation / test
    splits.

    Example::

        from pybind11_ke.data import KGEDataLoader, BernSampler, TradTestSampler

        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1,
        test: bool = False,
        test_batch_size: int | None = None,
        type_constrain: bool = True,
        num_workers: int | None = None,
        train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = BernSampler,
        test_sampler: typing.Type[TestSampler] = TradTestSampler):

        """Create a KGEDataLoader object.

        The training sampler is always instantiated; the test sampler and the
        validation / test triples are only created when ``test`` is true.

        :param in_path: Dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param batch_size: batch size
        :type batch_size: int | None
        :param neg_ent: Number of negative triples built for each positive triple by replacing an entity; has no effect for CompGCN.
        :type neg_ent: int
        :param test: Whether to read the validation and test sets
        :type test: bool
        :param test_batch_size: test batch size
        :type test_batch_size: int | None
        :param type_constrain: Whether to report test results constrained by type_constrain.txt
        :type type_constrain: bool
        :param num_workers: Number of processes used to load data; ``None`` is passed to ``DataLoader`` as 0
        :type num_workers: int | None
        :param train_sampler: Training data sampler class
        :type train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]]
        :param test_sampler: Test data sampler class
        :type test_sampler: typing.Type[TestSampler]
        """

        #: Dataset directory
        self.in_path: str = in_path
        #: entity2id.txt
        self.ent_file: str = ent_file
        #: relation2id.txt
        self.rel_file: str = rel_file
        #: train2id.txt
        self.train_file: str = train_file
        #: valid2id.txt
        self.valid_file: str = valid_file
        #: test2id.txt
        self.test_file: str = test_file
        #: batch size
        self.batch_size: int | None = batch_size
        #: Number of negative triples built for each positive triple by replacing an entity; has no effect for CompGCN.
        self.neg_ent: int = neg_ent
        #: Whether to read the validation and test sets
        self.test: bool = test
        #: test batch size
        self.test_batch_size: int | None = test_batch_size
        #: Whether to report test results constrained by type_constrain.txt
        self.type_constrain: bool = type_constrain
        #: Number of processes used to load data (stored as given; may be None)
        self.num_workers: int | None = num_workers

        #: Training data sampler
        self.train_sampler: typing.Union[UniSampler, BernSampler, RGCNSampler, CompGCNSampler] = train_sampler(
            in_path=self.in_path,
            ent_file=self.ent_file,
            rel_file=self.rel_file,
            train_file=self.train_file,
            batch_size=self.batch_size,
            neg_ent=self.neg_ent
        )
        #: Training triples
        self.data_train: list[tuple[int, int, int]] = self.train_sampler.get_train()

        if self.test:
            #: Test data sampler (built on top of the training sampler)
            self.test_sampler: TestSampler = test_sampler(
                sampler=self.train_sampler,
                valid_file=self.valid_file,
                test_file=self.test_file,
                type_constrain=self.type_constrain
            )
            #: Validation triples
            self.data_val: list[tuple[int, int, int]] = self.test_sampler.get_valid()
            #: Test triples
            self.data_test: list[tuple[int, int, int]] = self.test_sampler.get_test()

    def get_ent_tol(self) -> int:

        """Return the number of entities.

        :returns: Number of entities, as reported by the training sampler
        :rtype: int
        """

        return self.train_sampler.ent_tol

    def get_rel_tol(self) -> int:

        """Return the number of relations.

        :returns: Number of relations, as reported by the training sampler
        :rtype: int
        """

        return self.train_sampler.rel_tol

    def _worker_count(self) -> int:

        """Return the ``num_workers`` value handed to ``DataLoader``.

        :returns: ``self.num_workers``, or 0 when it is ``None``
        :rtype: int
        """

        # DataLoader requires an int here (it evaluates ``num_workers < 0``),
        # so the constructor default of None is mapped to DataLoader's own
        # default of 0 instead of raising a TypeError.
        return 0 if self.num_workers is None else self.num_workers

    def _eval_dataloader(self, split_attr: str) -> DataLoader:

        """Build the non-shuffled loader shared by the validation and test splits.

        :param split_attr: Name of the attribute holding the triples (``"data_val"`` or ``"data_test"``)
        :type split_attr: str
        :returns: Data loader over the requested split
        :rtype: torch.utils.data.DataLoader
        :raises AttributeError: If the evaluation data was never loaded (``test=False``)
        """

        # These attributes only exist when __init__ ran with test=True; keep
        # the original exception type but say what went wrong.
        if not (hasattr(self, "test_sampler") and hasattr(self, split_attr)):
            raise AttributeError(
                f"'{split_attr}' is not available: create KGEDataLoader with "
                "test=True to load the validation and test sets."
            )

        return DataLoader(
            getattr(self, split_attr),
            shuffle=False,
            batch_size=self.test_batch_size,
            num_workers=self._worker_count(),
            pin_memory=True,
            collate_fn=self.test_sampler.sampling,
        )

    def train_dataloader(self) -> DataLoader:

        """Return the training data loader.

        The loader shuffles the training triples, drops the last incomplete
        batch and collates with the training sampler's ``sampling`` method.

        NOTE(review): ``drop_last=True`` is combined with ``self.batch_size``;
        ``DataLoader`` rejects ``batch_size=None`` together with ``drop_last``,
        so ``batch_size`` presumably has to be set before calling this.

        :returns: Training data loader
        :rtype: torch.utils.data.DataLoader
        """

        return DataLoader(
            self.data_train,
            shuffle=True,
            batch_size=self.batch_size,
            num_workers=self._worker_count(),
            pin_memory=True,
            drop_last=True,
            collate_fn=self.train_sampler.sampling,
        )

    def val_dataloader(self) -> DataLoader:

        """Return the validation data loader.

        :returns: Validation data loader
        :rtype: torch.utils.data.DataLoader
        :raises AttributeError: If the loader was created with ``test=False``
        """

        return self._eval_dataloader("data_val")

    def test_dataloader(self) -> DataLoader:

        """Return the test data loader.

        :returns: Test data loader
        :rtype: torch.utils.data.DataLoader
        :raises AttributeError: If the loader was created with ``test=False``
        """

        return self._eval_dataloader("data_test")
def get_kge_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter optimization configuration of :py:class:`KGEDataLoader`.

    Every entry maps a parameter name to a one-item dictionary keyed by either
    ``'value'`` or ``'values'``. A new dictionary is built on each call.

    The default configuration is::

        parameters_dict = {
            'dataloader': {
                'value': 'KGEDataLoader'
            },
            'in_path': {
                'value': './'
            },
            'ent_file': {
                'value': 'entity2id.txt'
            },
            'rel_file': {
                'value': 'relation2id.txt'
            },
            'train_file': {
                'value': 'train2id.txt'
            },
            'valid_file': {
                'value': 'valid2id.txt'
            },
            'test_file': {
                'value': 'test2id.txt'
            },
            'batch_size': {
                'values': [512, 1024, 2048, 4096]
            },
            'neg_ent': {
                'values': [1, 4, 16, 64]
            },
            'test_batch_size': {
                'value': 30
            },
            'type_constrain': {
                'value': True
            },
            'num_workers': {
                'value': 16
            },
            'train_sampler': {
                'value': 'BernSampler'
            },
            'test_sampler': {
                'value': 'TradTestSampler'
            }
        }

    :returns: Default hyperparameter optimization configuration of :py:class:`KGEDataLoader`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # (parameter name, inner key, setting), listed in the order in which the
    # entries appear in the returned dictionary.
    entries = (
        ('dataloader', 'value', 'KGEDataLoader'),
        ('in_path', 'value', './'),
        ('ent_file', 'value', 'entity2id.txt'),
        ('rel_file', 'value', 'relation2id.txt'),
        ('train_file', 'value', 'train2id.txt'),
        ('valid_file', 'value', 'valid2id.txt'),
        ('test_file', 'value', 'test2id.txt'),
        ('batch_size', 'values', [512, 1024, 2048, 4096]),
        ('neg_ent', 'values', [1, 4, 16, 64]),
        ('test_batch_size', 'value', 30),
        ('type_constrain', 'value', True),
        ('num_workers', 'value', 16),
        ('train_sampler', 'value', 'BernSampler'),
        ('test_sampler', 'value', 'TradTestSampler'),
    )

    return {name: {inner_key: setting} for name, inner_key, setting in entries}