TrainDataLoader
- class pybind11_ke.data.TrainDataLoader(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, nbatches: int | None = None, threads: int = 8, sampling_mode: str = 'normal', bern: bool = True, neg_ent: int = 1, neg_rel: int = 0)
Fetches data from the underlying C++ module for training KGE models.
Example:
    from pybind11_ke.config import Trainer
    from pybind11_ke.module.model import TransE
    from pybind11_ke.module.loss import MarginLoss
    from pybind11_ke.module.strategy import NegativeSampling
    from pybind11_ke.data import TrainDataLoader

    # dataloader for training
    train_dataloader = TrainDataLoader(
        in_path = "../../benchmarks/FB15K/",
        nbatches = 200,
        threads = 8,
        sampling_mode = "normal",
        bern = False,
        neg_ent = 25,
        neg_rel = 0)

    # define the model
    transe = TransE(
        ent_tol = train_dataloader.get_ent_tol(),
        rel_tol = train_dataloader.get_rel_tol(),
        dim = 50,
        p_norm = 1,
        norm_flag = True)

    # define the loss function
    model = NegativeSampling(
        model = transe,
        loss = MarginLoss(margin = 1.0),
        batch_size = train_dataloader.get_batch_size()
    )

    # train the model
    # NOTE: `tester` is not defined in this excerpt; it must be created
    # before this call, or the `tester` and `test` arguments removed.
    trainer = Trainer(model = model, data_loader = train_dataloader,
        train_times = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
        tester = tester, test = True, valid_interval = 100,
        log_interval = 100, save_interval = 100,
        save_path = '../../checkpoint/transe.pth')
    trainer.run()
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, nbatches: int | None = None, threads: int = 8, sampling_mode: str = 'normal', bern: bool = True, neg_ent: int = 1, neg_rel: int = 0)
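Creates a TrainDataLoader object.
- Parameters:
in_path (str) – Dataset directory
ent_file (str) – Entity file name (default: entity2id.txt)
rel_file (str) – Relation file name (default: relation2id.txt)
train_file (str) – Training file name (default: train2id.txt)
batch_size (int) – Can be computed from nbatches. At least one of batch_size and nbatches must be provided; when both are given, batch_size takes precedence.
nbatches (int) – Can be computed from batch_size. At least one of batch_size and nbatches must be provided; when both are given, batch_size takes precedence.
threads (int) – Number of threads used by the underlying C++ data processing
sampling_mode (str) – Data sampling mode: normal performs ordinary negative sampling; cross performs negative sampling by alternately replacing head and tail.
bern (bool) – Whether to use the negative sampling method proposed in TransH
neg_ent (int) – Number of negative triples built for each positive triple by replacing an entity (head + tail)
neg_rel (int) – Number of negative triples built for each positive triple by replacing the relation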
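The relationship between batch_size and nbatches described above can be sketched as follows. This is only an illustration, not the library's code: the helper names resolve_batching and triples_per_batch are hypothetical, and both the floor division and the per-batch row count of batch_size * (1 + neg_ent + neg_rel) are assumptions about how the loader behaves. The triple count 483142 is used purely as sample input.

```python
def resolve_batching(triple_total, batch_size=None, nbatches=None):
    """Hypothetical helper: derive the missing value of (batch_size, nbatches).

    Assumes floor division, so a final partial batch is dropped.
    """
    if batch_size is None and nbatches is None:
        raise ValueError("provide at least one of batch_size and nbatches")
    if batch_size is not None:
        # batch_size takes precedence: any nbatches passed in is recomputed.
        nbatches = triple_total // batch_size
    else:
        batch_size = triple_total // nbatches
    return batch_size, nbatches


def triples_per_batch(batch_size, neg_ent=1, neg_rel=0):
    """Hypothetical helper: positives plus negatives in one batch.

    Assumes each positive triple contributes neg_ent + neg_rel negatives.
    """
    return batch_size * (1 + neg_ent + neg_rel)


if __name__ == "__main__":
    batch_size, nbatches = resolve_batching(483142, nbatches=200)
    print(batch_size, nbatches)                                  # 2415 200
    print(triples_per_batch(batch_size, neg_ent=25, neg_rel=0))  # 62790
```

Under these assumptions, the example configuration (nbatches = 200, neg_ent = 25) would yield batches of 2415 positive triples and 62790 triples in total.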
- __iter__() -> TrainDataSampler
Iterator function.
Implements iterator.__iter__(). Depending on sampling_mode, the returned sampler draws its batches from either sampling() or cross_sampling().
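A minimal sketch of this dispatch, using toy stand-ins rather than the real classes. ToySampler and ToyLoader are hypothetical names, the string batches are placeholders, and the assumption that the sampler calls the chosen function once per batch is not confirmed by this page.

```python
class ToySampler:
    """Hypothetical stand-in for TrainDataSampler.

    Calls `sample_fn` once per batch, for `nbatches` batches.
    """

    def __init__(self, nbatches, sample_fn):
        self.nbatches = nbatches
        self.sample_fn = sample_fn
        self.batch = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.batch >= self.nbatches:
            raise StopIteration
        self.batch += 1
        return self.sample_fn()


class ToyLoader:
    """Hypothetical loader that picks a sampling function from sampling_mode."""

    def __init__(self, nbatches, sampling_mode="normal"):
        self.nbatches = nbatches
        self.sampling_mode = sampling_mode

    def sampling(self):
        return "normal-batch"  # placeholder for a real batch

    def cross_sampling(self):
        return "cross-batch"  # placeholder for a real batch

    def __iter__(self):
        # Dispatch on sampling_mode, as the description above states.
        if self.sampling_mode == "normal":
            return ToySampler(self.nbatches, self.sampling)
        return ToySampler(self.nbatches, self.cross_sampling)


if __name__ == "__main__":
    for batch in ToyLoader(nbatches=2, sampling_mode="cross"):
        print(batch)  # prints "cross-batch" twice
```

Because __iter__ builds a fresh sampler on every call, a loader written this way can be iterated once per epoch without being reset by hand.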
- __weakref__
list of weak references to the object (if defined)
- get_batch_size() -> int
Returns batch_size.
- Returns: batch_size
- Return type: int