Shortcuts

pybind11_ke.config.Tester 源代码

# coding:utf-8
#
# pybind11_ke/config/Tester.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2023
#
# 该脚本定义了验证模型类.

"""
Tester - 验证模型类,内部使用 ``tqdm`` 实现进度条。
"""

import base
import torch
import typing
import numpy as np
from tqdm import tqdm
from ..data import TestDataLoader, GraphDataLoader
from ..module.model import Model

class Tester(object):
    """Evaluation helper for KGE models.

    Example::

        from pybind11_ke.config import Trainer, Tester

        # test the model
        transe.load_checkpoint('../checkpoint/transe.ckpt')
        tester = Tester(model = transe, data_loader = test_dataloader, use_gpu = True)
        tester.run_link_prediction()
    """

    # NOTE(review): the usage example above calls ``run_link_prediction()``,
    # which is not among the methods defined in this class body — confirm
    # where it is provided.

    def __init__(
        self,
        model: "typing.Union[Model, None]" = None,
        data_loader: "TestDataLoader | GraphDataLoader | None" = None,
        sampling_mode: str = 'link_test',
        use_gpu: bool = True,
        device: str = "cuda:0"):
        """Create a Tester.

        :param model: the KGE model to evaluate; may be ``None`` (the default)
        :type model: :py:class:`pybind11_ke.module.model.Model`
        :param data_loader: TestDataLoader or GraphDataLoader
        :type data_loader: :py:class:`pybind11_ke.data.TestDataLoader` or :py:class:`pybind11_ke.data.GraphDataLoader`
        :param sampling_mode: negative-sampling mode used with
            :py:class:`pybind11_ke.data.TestDataLoader`: ``link_test`` or ``link_valid``
        :type sampling_mode: str
        :param use_gpu: whether to use a GPU; when true and a model is given,
            the model is moved to ``device``
        :type use_gpu: bool
        :param device: which GPU to use, passed to :py:class:`torch.device`
        :type device: str
        """
        #: KGE model, i.e. :py:class:`pybind11_ke.module.model.Model`
        self.model: typing.Union[Model, None] = model
        #: :py:class:`pybind11_ke.data.TestDataLoader` or :py:class:`pybind11_ke.data.GraphDataLoader`
        self.data_loader: TestDataLoader | GraphDataLoader | None = data_loader
        #: negative-sampling mode for :py:class:`pybind11_ke.data.TestDataLoader`: ``link_test`` or ``link_valid``
        self.sampling_mode: str = sampling_mode
        #: whether to use a GPU
        self.use_gpu: bool = use_gpu
        #: the :py:class:`torch.device` built from ``device``
        self.device: torch.device = torch.device(device)

        # ``model`` defaults to None; calling ``.cuda`` on None would raise
        # AttributeError, so only move the model when one was supplied.
        if self.use_gpu and self.model is not None:
            self.model.cuda(device = self.device)

    def to_var(
        self,
        x: np.ndarray,
        use_gpu: bool) -> torch.Tensor:
        """Convert ``x`` to a tensor, placing it on :py:attr:`device` when ``use_gpu`` is true.

        :param x: input data
        :type x: numpy.ndarray
        :param use_gpu: whether to move the tensor to :py:attr:`device`
        :type use_gpu: bool
        :returns: the tensor (shares memory with ``x`` when it stays on the CPU)
        :rtype: torch.Tensor
        """
        tensor = torch.from_numpy(x)
        if use_gpu:
            return tensor.to(self.device)
        return tensor

    def test_one_step(
        self,
        data: dict[str, typing.Union[np.ndarray, str]]) -> np.ndarray:
        """Run the model's ``predict`` on one batch produced by :py:attr:`data_loader`.

        :param data: a batch produced by :py:meth:`pybind11_ke.data.TestDataLoader.sampling`;
            the keys ``batch_h``, ``batch_t``, ``batch_r`` (arrays) and ``mode`` are read
        :type data: dict[str, typing.Union[np.ndarray, str]]
        :returns: whatever ``self.model.predict`` returns (documented as the triple scores)
        :rtype: numpy.ndarray
        """
        return self.model.predict({
            'batch_h': self.to_var(data['batch_h'], self.use_gpu),
            'batch_t': self.to_var(data['batch_t'], self.use_gpu),
            'batch_r': self.to_var(data['batch_r'], self.use_gpu),
            'mode': data['mode']
        })

    def set_sampling_mode(self, sampling_mode: str):
        """Set :py:attr:`sampling_mode`.

        :param sampling_mode: data sampling mode; ``link_test`` and ``link_valid``
            select negative sampling over the test and validation sets,
            respectively, for link prediction
        :type sampling_mode: str
        """
        self.sampling_mode = sampling_mode
def get_tester_hpo_config() -> dict[str, dict[str, typing.Any]]:
    """Return the default hyperparameter-optimization config for :py:class:`Tester`.

    The default config is::

        parameters_dict = {
            'tester': {
                'value': 'Tester'
            },
            'use_gpu': {
                'value': True
            },
            'device': {
                'value': 'cuda:0'
            },
        }

    A fresh dictionary is built on every call, so callers may mutate the
    result without affecting later calls.

    :returns: the default hyperparameter-optimization config for :py:class:`Tester`
    :rtype: dict[str, dict[str, typing.Any]]
    """
    # (parameter name, fixed value) pairs, in the order they appear in the config.
    defaults = (
        ('tester', 'Tester'),
        ('use_gpu', True),
        ('device', 'cuda:0'),
    )
    return {name: {'value': fixed} for name, fixed in defaults}

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs