Shortcuts

pybind11_ke.data.TestDataLoader 源代码

# coding:utf-8
#
# pybind11_ke/data/TestDataLoader.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
#
# 该脚本定义了采样数据的函数, 用于验证模型.

"""
TestDataLoader - evaluation data loader class, similar to :py:class:`torch.utils.data.DataLoader`.
"""

import base
import typing
import numpy as np
from ..utils import construct_type_constrain
from collections.abc import Callable

class TestDataSampler(object):
    """Iterator wrapper around :py:meth:`pybind11_ke.data.TestDataLoader.sampling`.

    Each call to ``next()`` invokes the wrapped sampler once; iteration stops
    after ``data_total`` calls. The iterator is single-pass: :py:meth:`__iter__`
    returns ``self`` without resetting the counter.
    """

    def __init__(
            self,
            data_total: int,
            sampler: Callable[[], list[dict[str, typing.Union[np.ndarray, str]]]]):
        """Create a TestDataSampler.

        :param data_total: number of triples in the evaluated split; this many
            batches are produced before ``StopIteration``
        :type data_total: int
        :param sampler: zero-argument sampler, normally
            :py:meth:`pybind11_ke.data.TestDataLoader.sampling`, which returns a
            list of batch dicts (one head batch and one tail batch)
        :type sampler: Callable[[], list[dict[str, typing.Union[np.ndarray, str]]]]
        """
        #: number of triples in the evaluated split
        self.data_total: int = data_total
        #: the :py:meth:`pybind11_ke.data.TestDataLoader.sampling` function
        self.sampler: Callable[[], list[dict[str, typing.Union[np.ndarray, str]]]] = sampler
        #: number of ``__next__`` calls made so far
        self.total: int = 0

    def __iter__(self):
        """Iterator protocol :py:meth:`iterator.__iter__`; returns ``self``."""
        return self

    def __next__(self) -> list[dict[str, typing.Union[np.ndarray, str]]]:
        """Iterator protocol :py:meth:`iterator.__next__`.

        :returns: one sampled batch, i.e. whatever the sampler returns
        :rtype: list[dict[str, typing.Union[np.ndarray, str]]]
        :raises StopIteration: once ``data_total`` batches have been produced
        """
        self.total += 1
        if self.total > self.data_total:
            raise StopIteration()
        return self.sampler()

    def __len__(self) -> int:
        """Support ``len()`` (:py:meth:`object.__len__`).

        :returns: :py:attr:`data_total`
        :rtype: int
        """
        return self.data_total
class TestDataLoader(object):
    """Fetch evaluation data from the underlying C++ module for KGE model evaluation.

    Example::

        from pybind11_ke.config import Tester
        from pybind11_ke.data import TestDataLoader

        # dataloader for test
        test_dataloader = TestDataLoader('../../benchmarks/FB15K/')

        # test the model
        tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True, device = 'cuda:1')
    """

    def __init__(
            self,
            in_path: str = "./",
            ent_file: str = "entity2id.txt",
            rel_file: str = "relation2id.txt",
            train_file: str = "train2id.txt",
            valid_file: str = "valid2id.txt",
            test_file: str = "test2id.txt",
            sampling_mode: str = 'link_test',
            type_constrain: bool = True):
        """Create a TestDataLoader and immediately load the dataset via :py:meth:`read`.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param sampling_mode: sampling mode; ``link_test`` and ``link_valid``
            sample link-prediction data from the test set and validation set
            respectively
        :type sampling_mode: str
        :param type_constrain: whether to use type_constrain.txt for negative sampling
        :type type_constrain: bool
        """
        #: dataset directory
        self.in_path: str = in_path
        #: entity2id.txt
        self.ent_file: str = ent_file
        #: relation2id.txt
        self.rel_file: str = rel_file
        #: train2id.txt
        self.train_file: str = train_file
        #: valid2id.txt
        self.valid_file: str = valid_file
        #: test2id.txt
        self.test_file: str = test_file
        #: sampling mode: ``link_test`` (test set) or ``link_valid`` (validation set)
        self.sampling_mode: str = sampling_mode
        #: whether to use type_constrain.txt for negative sampling
        self.type_constrain: bool = type_constrain

        #: number of entities
        self.ent_tol: int = 0
        #: number of relations
        self.rel_tol: int = 0
        #: number of test-set triples
        self.test_tol: int = 0
        #: number of validation-set triples
        self.valid_tol: int = 0

        # Load the data.
        self.read()

    def read(self):
        """Have the underlying C++ module read the dataset, via `pybind11 <https://github.com/pybind/pybind11>`__.

        Configures the C++ module's file paths and loads the validation/test
        files. If :py:attr:`type_constrain` is set, it also calls
        ``construct_type_constrain`` and loads the type files. Finally it
        fetches the entity/relation/split totals and allocates the numpy
        buffers that the C++ side fills.
        """
        print("Start reading validation and testing data...")
        base.set_in_path(self.in_path)
        base.set_ent_path(self.ent_file)
        base.set_rel_path(self.rel_file)
        base.set_train_path(self.train_file)
        base.set_valid_path(self.valid_file)
        base.set_test_path(self.test_file)
        base.rand_reset()
        base.read_test_files()
        if self.type_constrain:
            construct_type_constrain(
                in_path=self.in_path,
                train_file=self.train_file,
                valid_file=self.valid_file,
                test_file=self.test_file)
            base.read_type_files()
        print("Validation and testing data read completed.\n")

        self.ent_tol = base.get_entity_total()
        self.rel_tol = base.get_relation_total()
        self.test_tol = base.get_test_total()
        self.valid_tol = base.get_valid_total()

        # Preallocate buffers with np.zeros; the C++ module writes into them in place.
        # Link-prediction buffers: one slot per entity (all candidate replacements).
        self.test_h = np.zeros(self.ent_tol, dtype=np.int64)
        self.test_t = np.zeros(self.ent_tol, dtype=np.int64)
        self.test_r = np.zeros(self.ent_tol, dtype=np.int64)
        # NOTE(review): the pos/neg buffers are not used within this class;
        # presumably consumed elsewhere (e.g. triple classification) — verify.
        self.test_pos_h = np.zeros(self.test_tol, dtype=np.int64)
        self.test_pos_t = np.zeros(self.test_tol, dtype=np.int64)
        self.test_pos_r = np.zeros(self.test_tol, dtype=np.int64)
        self.test_neg_h = np.zeros(self.test_tol, dtype=np.int64)
        self.test_neg_t = np.zeros(self.test_tol, dtype=np.int64)
        self.test_neg_r = np.zeros(self.test_tol, dtype=np.int64)

    # Sample link-prediction data.
    def sampling(self) -> list[dict[str, typing.Union[np.ndarray, str]]]:
        """Sample link-prediction data for one positive triple.

        The head and then the tail are replaced in turn by every entity, giving
        2 * :py:attr:`ent_tol` triples. Which positive triple is used is decided
        by the C++ module (presumably it keeps its own cursor, reset by
        ``base.init_test()`` — confirm in the C++ source).

        All returned arrays are copies. The C++ module refills the shared buffers
        :py:attr:`test_h` / :py:attr:`test_t` / :py:attr:`test_r` on every call,
        so returning views would let the next call silently overwrite a batch
        the caller still holds.

        :returns: two batches. In ``head_batch``, ``batch_h`` holds
            :py:attr:`ent_tol` candidate heads and ``batch_t``/``batch_r`` hold
            one element each. In ``tail_batch``, ``batch_t`` holds
            :py:attr:`ent_tol` candidate tails and ``batch_h``/``batch_r`` hold
            one element each.
        :rtype: list[dict[str, typing.Union[np.ndarray, str]]]
        """
        res = []
        base.get_head_batch(self.test_h, self.test_t, self.test_r, self.sampling_mode)
        res.append({
            "batch_h": self.test_h.copy(),
            "batch_t": self.test_t[:1].copy(),
            "batch_r": self.test_r[:1].copy(),
            "mode": "head_batch"
        })
        base.get_tail_batch(self.test_h, self.test_t, self.test_r, self.sampling_mode)
        res.append({
            "batch_h": self.test_h[:1].copy(),
            "batch_t": self.test_t.copy(),
            "batch_r": self.test_r[:1].copy(),
            "mode": "tail_batch"
        })
        return res

    def get_ent_tol(self) -> int:
        """Return :py:attr:`ent_tol`.

        :returns: :py:attr:`ent_tol`
        :rtype: int
        """
        return self.ent_tol

    def get_rel_tol(self) -> int:
        """Return :py:attr:`rel_tol`.

        :returns: :py:attr:`rel_tol`
        :rtype: int
        """
        return self.rel_tol

    def get_test_tol(self) -> int:
        """Return :py:attr:`test_tol`.

        :returns: :py:attr:`test_tol`
        :rtype: int
        """
        return self.test_tol

    def get_valid_tol(self) -> int:
        """Return :py:attr:`valid_tol`.

        :returns: :py:attr:`valid_tol`
        :rtype: int
        """
        return self.valid_tol

    def set_sampling_mode(self, sampling_mode: str):
        """Set :py:attr:`sampling_mode`.

        :param sampling_mode: ``link_test`` or ``link_valid``, selecting the
            test set or validation set for link-prediction sampling
        :type sampling_mode: str
        """
        self.sampling_mode = sampling_mode

    def _total_for_mode(self) -> int:
        """Return the number of triples to evaluate for the current :py:attr:`sampling_mode`.

        :returns: :py:attr:`test_tol` for ``link_test``, :py:attr:`valid_tol` for ``link_valid``
        :rtype: int
        :raises ValueError: if :py:attr:`sampling_mode` is neither value
        """
        if self.sampling_mode == "link_test":
            return self.test_tol
        if self.sampling_mode == "link_valid":
            return self.valid_tol
        raise ValueError(
            "pybind11_ke.data.TestDataLoader.sampling_mode can only be a link_test or link_valid, "
            f"got {self.sampling_mode!r}.")

    def __iter__(self) -> "TestDataSampler":
        """Iterator protocol :py:meth:`iterator.__iter__`.

        :py:attr:`sampling_mode` selects whether the validation set or the test
        set is evaluated. The mode is validated before the C++ module is reset.

        :raises ValueError: if :py:attr:`sampling_mode` is invalid
        """
        total = self._total_for_mode()
        base.init_test()
        return TestDataSampler(total, self.sampling)

    def __len__(self) -> int:
        """Support ``len()`` (:py:meth:`object.__len__`).

        :returns: :py:attr:`test_tol` or :py:attr:`valid_tol`, depending on :py:attr:`sampling_mode`
        :rtype: int
        :raises ValueError: if :py:attr:`sampling_mode` is invalid (previously
            ``None`` was returned, which made ``len()`` fail with ``TypeError``)
        """
        return self._total_for_mode()
def get_test_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]:
    """Return the default hyperparameter-optimization config for :py:class:`TestDataLoader`.

    The default config is::

        parameters_dict = {
            'valid_file': {
                'value': 'valid2id.txt'
            },
            'test_file': {
                'value': 'test2id.txt'
            },
            'type_constrain': {
                'value': True
            }
        }

    A new dict is built on every call, so callers may mutate the result freely.

    :returns: the default hyperparameter-optimization config for :py:class:`TestDataLoader`
    :rtype: dict[str, dict[str, typing.Any]]
    """
    # (parameter name, fixed value) pairs, in the order they appear in the config.
    fixed_values = (
        ('valid_file', 'valid2id.txt'),
        ('test_file', 'test2id.txt'),
        ('type_constrain', True),
    )
    return {param: {'value': setting} for param, setting in fixed_values}

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs