Shortcuts

TrainDataLoader

class pybind11_ke.data.TrainDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, nbatches: int | None = None, threads: int = 8, sampling_mode: str = 'normal', bern: bool = True, neg_ent: int = 1, neg_rel: int = 0)[源代码]

主要从底层 C++ 模块获得数据用于 KGE 模型的训练。

例子:

from pybind11_ke.config import Trainer
from pybind11_ke.module.model import TransE
from pybind11_ke.module.loss import MarginLoss
from pybind11_ke.module.strategy import NegativeSampling
from pybind11_ke.data import TrainDataLoader

# dataloader for training
train_dataloader = TrainDataLoader(
        in_path = "../../benchmarks/FB15K/", 
        nbatches = 200,
        threads = 8, 
        sampling_mode = "normal", 
        bern = False,  
        neg_ent = 25,
        neg_rel = 0)

# define the model
transe = TransE(
        ent_tol = train_dataloader.get_ent_tol(),
        rel_tol = train_dataloader.get_rel_tol(),
        dim = 50, 
        p_norm = 1, 
        norm_flag = True)

# define the loss function
model = NegativeSampling(
        model = transe, 
        loss = MarginLoss(margin = 1.0),
        batch_size = train_dataloader.get_batch_size()
)

# train the model
trainer = Trainer(model = model, data_loader = train_dataloader,
        train_times = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
        tester = tester, test = True, valid_interval = 100,
        log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transe.pth')
trainer.run()
__init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, nbatches: int | None = None, threads: int = 8, sampling_mode: str = 'normal', bern: bool = True, neg_ent: int = 1, neg_rel: int = 0)[源代码]

创建 TrainDataLoader 对象。

参数:
  • in_path (str) – 数据集目录

  • ent_file (str) – entity2id.txt

  • rel_file (str) – relation2id.txt

  • train_file (str) – train2id.txt

  • batch_size (int) – batch_size 可以根据 nbatches 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高

  • nbatches (int) – nbatches 可以根据 batch_size 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高

  • threads (int) – 底层 C++ 数据处理所需要的线程数

  • sampling_mode (str) – 数据采样模式,normal 表示正常负采样,cross 表示交替替换 head 和 tail 进行负采样

  • bern (bool) – 是否使用 TransH 提出的负采样方法进行负采样

  • neg_ent (int) – 对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)

  • neg_rel (int) – 对于每一个正三元组, 构建的负三元组的个数, 替换 relation

__iter__() TrainDataSampler[源代码]

迭代器函数 iterator.__iter__(), 根据 sampling_mode 选择返回 sampling()cross_sampling()

__len__() int[源代码]

len() 要求 object.__len__()

返回:

nbatches

返回类型:

int

__weakref__

list of weak references to the object (if defined)

batch_size: int

batch_size 可以根据 nbatches 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高

bern: bool

是否使用 TransH 提出的负采样方法进行负采样

cross_sampling() dict[str, Union[numpy.ndarray, str]][源代码]

交替替换 head 和 tail 进行负采样, 生成 1 batch 数据

返回:

1 batch 数据

返回类型:

dict[str, Union[np.ndarray, str]]

ent_file: str

entity2id.txt

ent_tol: int

实体的个数

get_batch_size() int[源代码]

返回 batch_size

返回:

batch_size

返回类型:

int

get_ent_tol() int[源代码]

返回 ent_tol

返回:

ent_tol

返回类型:

int

get_rel_tol() int[源代码]

返回 rel_tol

返回:

rel_tol

返回类型:

int

get_train_tot() int[源代码]

返回 train_tot

返回:

train_tot

返回类型:

int

in_path: str

数据集目录

nbatches: int

nbatches 可以根据 batch_size 计算得出,两者不可以同时不提供;同时指定时 batch_size 优先级更高

neg_ent: int

对于每一个正三元组, 构建的负三元组的个数, 替换 entity (head + tail)

neg_rel: int

对于每一个正三元组, 构建的负三元组的个数, 替换 relation

read()[源代码]

利用 pybind11 让底层 C++ 模块读取数据集中的数据

rel_file: str

relation2id.txt

rel_tol: int

关系的个数

sampling() dict[str, Union[numpy.ndarray, str]][源代码]

正常采样1 batch 数据,即 normal

返回:

1 batch 数据

返回类型:

dict[str, Union[np.ndarray, str]]

sampling_head() dict[str, Union[numpy.ndarray, str]][源代码]

只替换 head 进行负采样, 生成 1 batch 数据

返回:

1 batch 数据

返回类型:

dict[str, Union[np.ndarray, str]]

sampling_mode: str

数据采样模式,normal 表示正常负采样,cross 表示交替替换 head 和 tail 进行负采样

sampling_tail() dict[str, Union[numpy.ndarray, str]][源代码]

只替换 tail 进行负采样, 生成 1 batch 数据

返回:

1 batch 数据

返回类型:

dict[str, Union[np.ndarray, str]]

threads: int

底层 C++ 数据处理所需要的线程数

train_file: str

train2id.txt

train_tot: int

训练集三元组的个数

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs