NegativeSampling¶
- class pybind11_ke.module.strategy.NegativeSampling(*args: Any, **kwargs: Any)[源代码]¶
将模型和损失函数封装到一起,方便模型训练。
例子:
from pybind11_ke.config import Trainer from pybind11_ke.module.model import TransE from pybind11_ke.module.loss import MarginLoss from pybind11_ke.module.strategy import NegativeSampling # define the model transe = TransE( ent_tol = train_dataloader.get_ent_tol(), rel_tol = train_dataloader.get_rel_tol(), dim = 50, p_norm = 1, norm_flag = True) # define the loss function model = NegativeSampling( model = transe, loss = MarginLoss(margin = 1.0), batch_size = train_dataloader.get_batch_size() ) # train the model trainer = Trainer(model = model, data_loader = train_dataloader, train_times = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1', tester = tester, test = True, valid_interval = 100, log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transe.pth') trainer.run()
- __init__(model: Model | None = None, loss: Loss | None = None, batch_size: int = 256, regul_rate: float = 0.0, l3_regul_rate: float = 0.0)[源代码]¶
创建 NegativeSampling 对象。
- 参数:
model (
pybind11_ke.module.model.Model) – KGE 模型loss (
pybind11_ke.module.loss.Loss) – 损失函数。batch_size (int) – batch size
regul_rate (float) – 权重衰减系数
l3_regul_rate (float) – l3 正则化系数
- __weakref__¶
list of weak references to the object (if defined)
- _get_negative_score(score: torch.Tensor) torch.Tensor[源代码]¶
获得负样本的得分,由于底层 C++ 处理模块的原因, 所以正样本的得分处于前 batch size 位置,负样本处于正样本后面。
- 参数:
score – 所有样本的得分。
- 返回:
负样本的得分
- 返回类型:
- _get_positive_score(score: torch.Tensor) torch.Tensor[源代码]¶
获得正样本的得分,由于底层 C++ 处理模块的原因, 所以正样本的得分处于前 batch size 位置。
- 参数:
score – 所有样本的得分。
- 返回:
正样本的得分
- 返回类型:
- forward(data: dict[str, Union[torch.Tensor, str]]) torch.Tensor[源代码]¶
计算最后的损失值。定义每次调用时执行的计算。
torch.nn.Module子类必须重写torch.nn.Module.forward()。- 参数:
data (dict[str, Union[torch.Tensor,str]]) – 数据
- 返回:
损失值
- 返回类型:
- get_parameters(mode: str = 'numpy', param_dict: dict[str, Any] | None = None) dict[str, numpy.ndarray] | dict[str, list] | dict[str, torch.Tensor]¶
获得模型权重。
- loss: Loss¶
损失函数,即
pybind11_ke.module.loss.Loss
- model: Model¶
KGE 模型,即
pybind11_ke.module.model.Model
- pi_const: torch.nn.parameter.Parameter¶
常数 pi
- zero_const: torch.nn.parameter.Parameter¶
常数 0