Shortcuts

pybind11_ke.data.GraphDataLoader 源代码

# coding:utf-8
#
# pybind11_ke/data/GraphDataLoader.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2024
#
# 为图神经网络读取数据.

"""
GraphDataLoader - 图神经网络读取数据集类。
"""

import typing
from .GraphSampler import GraphSampler
from .CompGCNSampler import CompGCNSampler
from .GraphTestSampler import GraphTestSampler
from .CompGCNTestSampler import CompGCNTestSampler
from torch.utils.data import DataLoader

[文档]class GraphDataLoader: """基本图神经网络采样器。 例子:: from pybind11_ke.data import CompGCNSampler, CompGCNTestSampler, GraphDataLoader dataloader = GraphDataLoader( in_path = "../../benchmarks/FB15K237/", batch_size = 60000, neg_ent = 10, test = True, test_batch_size = 100, num_workers = 16 ) """
[文档] def __init__( self, in_path: str = "./", ent_file: str = "entity2id.txt", rel_file: str = "relation2id.txt", train_file: str = "train2id.txt", valid_file: str = "valid2id.txt", test_file: str = "test2id.txt", batch_size: int | None = None, neg_ent: int = 1, test: bool = False, test_batch_size: int | None = None, num_workers: int | None = None, train_sampler: typing.Union[typing.Type[GraphSampler], typing.Type[CompGCNSampler]] = GraphSampler, test_sampler: typing.Union[typing.Type[GraphTestSampler], typing.Type[CompGCNTestSampler]] = GraphTestSampler): """创建 GraphDataLoader 对象。 :param in_path: 数据集目录 :type in_path: str :param ent_file: entity2id.txt :type ent_file: str :param rel_file: relation2id.txt :type rel_file: str :param train_file: train2id.txt :type train_file: str :param valid_file: valid2id.txt :type valid_file: str :param test_file: test2id.txt :type test_file: str :param batch_size: batch size :type batch_size: int | None :param neg_ent: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail);对于 CompGCN 不起作用。 :type neg_ent: int :param test: 是否读取验证集和测试集 :type test: bool :param test_batch_size: test batch size :type test_batch_size: int | None :param num_workers: 加载数据的进程数 :type num_workers: int :param train_sampler: 训练数据采样器 :type train_sampler: typing.Union[typing.Type[GraphSampler], typing.Type[CompGCNSampler]] :param test_sampler: 测试数据采样器 :type test_sampler: typing.Union[typing.Type[GraphTestSampler], typing.Type[CompGCNTestSampler]] """ #: 数据集目录 self.in_path: str = in_path #: entity2id.txt self.ent_file: str = ent_file #: relation2id.txt self.rel_file: str = rel_file #: train2id.txt self.train_file: str = train_file #: valid2id.txt self.valid_file: str = valid_file #: test2id.txt self.test_file: str = test_file #: batch size self.batch_size: int = batch_size #: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail);对于 CompGCN 不起作用。 self.neg_ent: int = neg_ent #: 是否读取验证集和测试集 self.test: bool = test #: test batch size self.test_batch_size: int = test_batch_size #: 加载数据的进程数 self.num_workers: int = num_workers #: 训练数据采样器 self.train_sampler: typing.Union[typing.Type[GraphSampler], typing.Type[CompGCNSampler]] = train_sampler( in_path=self.in_path, ent_file=self.ent_file, rel_file=self.rel_file, train_file=self.train_file, batch_size=self.batch_size, neg_ent=self.neg_ent ) #: 训练集三元组 self.data_train: list[tuple[int, int, int]] = self.train_sampler.get_train() if self.test: #: 测试数据采样器 self.test_sampler: typing.Union[typing.Type[GraphTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler( sampler=self.train_sampler, valid_file=self.valid_file, test_file=self.test_file, ) #: 验证集三元组 self.data_val: list[tuple[int, int, int]] = self.test_sampler.get_valid() #: 测试集三元组 self.data_test: list[tuple[int, int, int]] = self.test_sampler.get_test()
[文档] def train_dataloader(self) -> DataLoader: """返回训练数据加载器。 :returns: 训练数据加载器 :rtype: torch.utils.data.DataLoader """ return DataLoader( self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last=True, collate_fn=self.train_sampler.sampling, )
[文档] def val_dataloader(self) -> DataLoader: """返回验证数据加载器。 :returns: 验证数据加载器 :rtype: torch.utils.data.DataLoader """ return DataLoader( self.data_val, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.test_sampler.sampling, )
[文档] def test_dataloader(self) -> DataLoader: """返回测试数据加载器。 :returns: 测试数据加载器 :rtype: torch.utils.data.DataLoader""" return DataLoader( self.data_test, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.test_sampler.sampling, )
[文档]def get_graph_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]: """返回 :py:class:`GraphDataLoader` 的默认超参数优化配置。 默认配置为:: parameters_dict = { 'dataloader': { 'value': 'GraphDataLoader' }, 'in_path': { 'value': './' }, 'ent_file': { 'value': 'entity2id.txt' }, 'rel_file': { 'value': 'relation2id.txt' }, 'train_file': { 'value': 'train2id.txt' }, 'valid_file': { 'value': 'valid2id.txt' }, 'test_file': { 'value': 'test2id.txt' }, 'batch_size': { 'values': [512, 1024, 2048, 4096] }, 'neg_ent': { 'values': [1, 4, 8, 16] }, 'test_batch_size': { 'value': 100 }, 'num_workers': { 'value': 16 }, 'train_sampler': { 'value': 'GraphSampler' }, 'test_sampler': { 'value': 'GraphTestSampler' } } :returns: :py:class:`GraphDataLoader` 的默认超参数优化配置 :rtype: dict[str, dict[str, typing.Any]] """ parameters_dict = { 'dataloader': { 'value': 'GraphDataLoader' }, 'in_path': { 'value': './' }, 'ent_file': { 'value': 'entity2id.txt' }, 'rel_file': { 'value': 'relation2id.txt' }, 'train_file': { 'value': 'train2id.txt' }, 'valid_file': { 'value': 'valid2id.txt' }, 'test_file': { 'value': 'test2id.txt' }, 'batch_size': { 'values': [512, 1024, 2048, 4096] }, 'neg_ent': { 'values': [1, 4, 8, 16] }, 'test_batch_size': { 'value': 100 }, 'num_workers': { 'value': 16 }, 'train_sampler': { 'value': 'GraphSampler' }, 'test_sampler': { 'value': 'GraphTestSampler' } } return parameters_dict

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs