Shortcuts

pybind11_ke.data.TrainDataLoader 源代码

# coding:utf-8
#
# pybind11_ke/data/TrainDataLoader.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2023
#
# 该脚本定义了为训练循环采样数据的类.

"""
TrainDataLoader - 数据集类,类似 :py:class:`torch.utils.data.DataLoader`。
"""

import base
import typing
import numpy as np
from collections.abc import Callable

[文档]class TrainDataSampler(object): """将 :py:meth:`pybind11_ke.data.TrainDataLoader.sampling` 或 :py:meth:`pybind11_ke.data.TrainDataLoader.cross_sampling` 进行封装。 """
[文档] def __init__( self, nbatches: int, sampler: Callable[[], dict[str, typing.Union[np.ndarray, str]]]): """创建 TrainDataSample 对象。 :param nbatches: 1 epoch 有多少个 batch :type nbatches: int :param sampler: 采样器 :type sampler: :py:meth:`pybind11_ke.data.TrainDataLoader.sampling` 或 :py:meth:`pybind11_ke.data.TrainDataLoader.cross_sampling` """ #: 1 epoch 有多少个 batch self.nbatches: int = nbatches #: :py:meth:`pybind11_ke.data.TrainDataLoader.sampling` #: 或 :py:meth:`pybind11_ke.data.TrainDataLoader.cross_sampling` 函数 self.sampler: Callable[[], dict[str, typing.Union[np.ndarray, str]]] = sampler self.batch: int = 0
[文档] def __iter__(self): """迭代器函数 :py:meth:`iterator.__iter__`""" return self
[文档] def __next__(self) -> dict[str, typing.Union[np.ndarray, str]]: """迭代器函数 :py:meth:`iterator.__next__` :returns: 采样一批数据 :rtype: dict[str, typing.Union[np.ndarray, str]] """ self.batch += 1 if self.batch > self.nbatches: raise StopIteration() return self.sampler()
[文档] def __len__(self) -> int: """len() 要求 :py:meth:`object.__len__` :returns: :py:attr:`nbatches` :rtype: int """ return self.nbatches
[文档]class TrainDataLoader(object): """ 主要从底层 C++ 模块获得数据用于 KGE 模型的训练。 例子:: from pybind11_ke.config import Trainer from pybind11_ke.module.model import TransE from pybind11_ke.module.loss import MarginLoss from pybind11_ke.module.strategy import NegativeSampling from pybind11_ke.data import TrainDataLoader # dataloader for training train_dataloader = TrainDataLoader( in_path = "../../benchmarks/FB15K/", nbatches = 200, threads = 8, sampling_mode = "normal", bern = False, neg_ent = 25, neg_rel = 0) # define the model transe = TransE( ent_tol = train_dataloader.get_ent_tol(), rel_tol = train_dataloader.get_rel_tol(), dim = 50, p_norm = 1, norm_flag = True) # define the loss function model = NegativeSampling( model = transe, loss = MarginLoss(margin = 1.0), batch_size = train_dataloader.get_batch_size() ) # train the model trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1', tester = tester, test = True, valid_interval = 100, log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transe.pth') trainer.run() """
[文档] def __init__( self, in_path: str = "./", ent_file: str = "entity2id.txt", rel_file: str = "relation2id.txt", train_file: str = "train2id.txt", batch_size: int | None = None, nbatches: int | None = None, threads: int = 8, sampling_mode: str = "normal", bern: bool = True, neg_ent: int = 1, neg_rel: int = 0): """创建 TrainDataLoader 对象。 :param in_path: 数据集目录 :type in_path: str :param ent_file: entity2id.txt :type ent_file: str :param rel_file: relation2id.txt :type rel_file: str :param train_file: train2id.txt :type train_file: str :param batch_size: batch_size 可以根据 nbatches 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高 :type batch_size: int :param nbatches: nbatches 可以根据 batch_size 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高 :type nbatches: int :param threads: 底层 C++ 数据处理所需要的线程数 :type threads: int :param sampling_mode: 数据采样模式,``normal`` 表示正常负采样,``cross`` 表示交替替换 head 和 tail 进行负采样 :type sampling_mode: str :param bern: 是否使用 TransH 提出的负采样方法进行负采样 :type bern: bool :param neg_ent: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail) :type neg_ent: int :param neg_rel: 对于每一个正三元组, 构建的负三元组的个数, 替换 relation :type neg_rel: int """ #: 数据集目录 self.in_path: str = in_path #: entity2id.txt self.ent_file: str = ent_file #: relation2id.txt self.rel_file: str = rel_file #: train2id.txt self.train_file: str = train_file #: batch_size 可以根据 nbatches 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高 self.batch_size: int = batch_size #: nbatches 可以根据 batch_size 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高 self.nbatches: int = nbatches #: 底层 C++ 数据处理所需要的线程数 self.threads: int = threads #: 数据采样模式,``normal`` 表示正常负采样,``cross`` 表示交替替换 head 和 tail 进行负采样 self.sampling_mode: str = sampling_mode #: 是否使用 TransH 提出的负采样方法进行负采样 self.bern: bool = bern #: 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail) self.neg_ent: int = neg_ent #: 对于每一个正三元组, 构建的负三元组的个数, 替换 relation self.neg_rel: int = neg_rel #: 实体的个数 self.ent_tol: int = 0 #: 关系的个数 self.rel_tol: int = 0 #: 训练集三元组的个数 self.train_tot: int = 0 self.cross_sampling_flag: int = 0 # 读入数据 self.read()
[文档] def read(self): """利用 `pybind11 <https://github.com/pybind/pybind11>`__ 让底层 C++ 模块读取数据集中的数据""" print("\nStart reading training data...") base.set_in_path(self.in_path) base.set_ent_path(self.ent_file) base.set_rel_path(self.rel_file) base.set_train_path(self.train_file) base.set_bern(self.bern) base.set_work_threads(self.threads) base.rand_reset() base.read_train_files() print("Training data read completed.\n") # 实体的个数 self.ent_tol = base.get_entity_total() # 关系的个数 self.rel_tol = base.get_relation_total() # 训练集三元组的个数 self.train_tot = base.get_train_total() if self.batch_size == None and self.nbatches == None: raise ValueError("batch_size or nbatches must specify one") elif self.batch_size == None: self.batch_size = self.train_tot // self.nbatches else: self.nbatches = self.train_tot // self.batch_size self.batch_seq_size = self.batch_size * (1 + self.neg_ent + self.neg_rel) # 利用 np.zeros 分配内存 self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32)
[文档] def sampling(self) -> dict[str, typing.Union[np.ndarray, str]]: """正常采样1 batch 数据,即 ``normal`` :returns: 1 batch 数据 :rtype: dict[str, typing.Union[np.ndarray, str]] """ base.sampling(self.batch_h, self.batch_t, self.batch_r, self.batch_y, self.batch_size, self.neg_ent, self.neg_rel, 0) return { "batch_h": self.batch_h, "batch_t": self.batch_t, "batch_r": self.batch_r, "batch_y": self.batch_y, "mode": "normal" }
[文档] def sampling_head(self) -> dict[str, typing.Union[np.ndarray, str]]: """只替换 head 进行负采样, 生成 1 batch 数据 :returns: 1 batch 数据 :rtype: dict[str, typing.Union[np.ndarray, str]] """ base.sampling(self.batch_h, self.batch_t, self.batch_r, self.batch_y, self.batch_size, self.neg_ent, self.neg_rel, 1) return { "batch_h": self.batch_h, "batch_t": self.batch_t[:self.batch_size], "batch_r": self.batch_r[:self.batch_size], "batch_y": self.batch_y, "mode": "head_batch" }
[文档] def sampling_tail(self) -> dict[str, typing.Union[np.ndarray, str]]: """只替换 tail 进行负采样, 生成 1 batch 数据 :returns: 1 batch 数据 :rtype: dict[str, typing.Union[np.ndarray, str]] """ base.sampling(self.batch_h, self.batch_t, self.batch_r, self.batch_y, self.batch_size, self.neg_ent, self.neg_rel, -1) return { "batch_h": self.batch_h[:self.batch_size], "batch_t": self.batch_t, "batch_r": self.batch_r[:self.batch_size], "batch_y": self.batch_y, "mode": "tail_batch" }
[文档] def cross_sampling(self) -> dict[str, typing.Union[np.ndarray, str]]: """交替替换 head 和 tail 进行负采样, 生成 1 batch 数据 :returns: 1 batch 数据 :rtype: dict[str, typing.Union[np.ndarray, str]] """ self.cross_sampling_flag = 1 - self.cross_sampling_flag if self.cross_sampling_flag == 0: return self.sampling_head() else: return self.sampling_tail()
[文档] def get_batch_size(self) -> int: """返回 :py:attr:`batch_size` :returns: :py:attr:`batch_size` :rtype: int """ return self.batch_size
[文档] def get_ent_tol(self) -> int: """返回 :py:attr:`ent_tol` :returns: :py:attr:`ent_tol` :rtype: int """ return self.ent_tol
[文档] def get_rel_tol(self) -> int: """返回 :py:attr:`rel_tol` :returns: :py:attr:`rel_tol` :rtype: int """ return self.rel_tol
[文档] def get_train_tot(self) -> int: """返回 :py:attr:`train_tot` :returns: :py:attr:`train_tot` :rtype: int """ return self.train_tot
[文档] def __iter__(self) -> TrainDataSampler: """迭代器函数 :py:meth:`iterator.__iter__`, 根据 :py:attr:`sampling_mode` 选择返回 :py:meth:`sampling` 和 :py:meth:`cross_sampling`""" if self.sampling_mode == "normal": return TrainDataSampler(self.nbatches, self.sampling) elif self.sampling_mode == "cross": return TrainDataSampler(self.nbatches, self.cross_sampling)
[文档] def __len__(self) -> int: """len() 要求 :py:meth:`object.__len__` :returns: :py:attr:`nbatches` :rtype: int """ return self.nbatches
[文档]def get_train_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]: """返回 :py:class:`TrainDataLoader` 的默认超参数优化配置。 默认配置为:: parameters_dict = { 'dataloader': { 'value': 'TrainDataLoader' }, 'in_path': { 'value': './' }, 'ent_file': { 'value': 'entity2id.txt' }, 'rel_file': { 'value': 'relation2id.txt' }, 'train_file': { 'value': 'train2id.txt' }, 'batch_size': { 'values': [512, 1024, 2048, 4096] }, 'threads': { 'value': 4 }, 'sampling_mode': { 'value': 'normal' }, 'bern': { 'value': True }, 'neg_ent': { 'values': [1, 4, 16, 32] }, 'neg_rel': { 'value': 0 } } :returns: :py:class:`TrainDataLoader` 的默认超参数优化配置 :rtype: dict[str, dict[str, typing.Any]] """ parameters_dict = { 'dataloader': { 'value': 'TrainDataLoader' }, 'in_path': { 'value': './' }, 'ent_file': { 'value': 'entity2id.txt' }, 'rel_file': { 'value': 'relation2id.txt' }, 'train_file': { 'value': 'train2id.txt' }, 'batch_size': { 'values': [512, 1024, 2048, 4096] }, 'threads': { 'value': 4 }, 'sampling_mode': { 'value': 'normal' }, 'bern': { 'value': True }, 'neg_ent': { 'values': [1, 4, 16, 32] }, 'neg_rel': { 'value': 0 } } return parameters_dict

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs