Shortcuts

pybind11_ke.data.GraphTestSampler 源代码

# coding:utf-8
#
# pybind11_ke/data/GraphTestSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2024
#
# R-GCN 的测试数据采样器.

"""
GraphTestSampler - R-GCN 的测试数据采样器。
"""

import os
import dgl
import torch
import typing
import numpy as np
from .GraphSampler import GraphSampler
from .CompGCNSampler import CompGCNSampler
from collections import defaultdict as ddict

class GraphTestSampler(object):
    """Test-data sampler for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from pybind11_ke.data import GraphTestSampler, CompGCNTestSampler
        from torch.utils.data import DataLoader

        #: test-data sampler
        test_sampler: typing.Union[typing.Type[GraphTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
            sampler=train_sampler,
            valid_file=valid_file,
            test_file=test_file,
        )

        #: validation triples
        data_val: list[tuple[int, int, int]] = test_sampler.get_valid()

        #: test triples
        data_test: list[tuple[int, int, int]] = test_sampler.get_test()

        val_dataloader = DataLoader(
            data_val,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )

        test_dataloader = DataLoader(
            data_test,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )
    """

    def __init__(
        self,
        sampler: typing.Union[GraphSampler, CompGCNSampler],
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt"):
        """Create a GraphTestSampler.

        Loads the validation/test triples, appends their reverse-relation
        counterparts, and builds the filter indices used by :py:meth:`sampling`.

        :param sampler: the training-data sampler; this object reads its
            ``ent_tol``, ``rel_tol``, ``in_path``, ``build_graph`` and its
            training triples (``t_triples`` or ``train_triples``).
        :type sampler: typing.Union[GraphSampler, CompGCNSampler]
        :param valid_file: validation file name inside ``sampler.in_path``
        :type valid_file: str
        :param test_file: test file name inside ``sampler.in_path``
        :type test_file: str
        """
        #: training-data sampler
        self.sampler: typing.Union[GraphSampler, CompGCNSampler] = sampler
        #: number of entities
        self.ent_tol: int = sampler.ent_tol
        #: training triples; CompGCNSampler exposes them as ``t_triples``,
        #: the other samplers as ``train_triples``
        self.triples: list[tuple[int, int, int]] = (
            self.sampler.t_triples
            if isinstance(self.sampler, CompGCNSampler)
            else self.sampler.train_triples
        )
        #: exponent forwarded to ``sampler.build_graph``
        #: (presumably the degree-normalisation power — TODO confirm)
        self.power: float = -1
        #: validation file name (e.g. valid2id.txt)
        self.valid_file: str = valid_file
        #: test file name (e.g. test2id.txt)
        self.test_file: str = test_file
        #: triple count declared on the first line of the validation file
        self.valid_tol: int = 0
        #: triple count declared on the first line of the test file
        self.test_tol: int = 0
        #: validation triples (h, r, t), forward then reverse
        self.valid_triples: list[tuple[int, int, int]] = []
        #: test triples (h, r, t), forward then reverse
        self.test_triples: list[tuple[int, int, int]] = []
        #: every known triple (train + valid + test, including reverses)
        self.all_true_triples: set[tuple[int, int, int]] = set()

        self.get_valid_test_triples_id()
        self.add_valid_test_reverse_triples()

        #: (h, r) -> index tensor of every known tail t
        self.hr2t_all: ddict = ddict(set)
        #: (r, t) -> index tensor of every known head h
        self.rt2h_all: ddict = ddict(set)
        self.get_hr2t_rt2h_from_all()

    def _read_triples(self, file_name: str) -> tuple[int, list[tuple[int, int, int]]]:
        """Parse one triple file located in ``self.sampler.in_path``.

        The first line holds the declared number of triples; every following
        non-blank line is ``head tail relation`` (note the order on disk).
        Blank lines are skipped instead of failing the unpacking.

        :param file_name: file name relative to ``self.sampler.in_path``
        :returns: the declared count and the triples reordered as (h, r, t)
        """
        path = os.path.join(self.sampler.in_path, file_name)
        triples: list[tuple[int, int, int]] = []
        with open(path, encoding="utf-8") as f:
            count = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                h, t, r = fields
                triples.append((int(h), int(r), int(t)))
        return count, triples

    def get_valid_test_triples_id(self):
        """Read :py:attr:`valid_file` and :py:attr:`test_file`.

        Sets :py:attr:`valid_tol` / :py:attr:`test_tol` from each file's
        header line and appends the forward (h, r, t) triples to
        :py:attr:`valid_triples` / :py:attr:`test_triples`.
        """
        self.valid_tol, valid = self._read_triples(self.valid_file)
        self.valid_triples.extend(valid)
        self.test_tol, test = self._read_triples(self.test_file)
        self.test_triples.extend(test)

    def add_valid_test_reverse_triples(self):
        """Append a reverse triple (t, r', h) for every validation/test triple (h, r, t).

        The reverse relation id is ``r' = r + rel_tol // 2``. ``rel_tol`` appears
        to count both directions, so half of it is the number of original relations.
        Afterwards :py:attr:`all_true_triples` is rebuilt from the training,
        validation and test triples.
        """
        offset = int(self.sampler.rel_tol) // 2
        for file_name, target in (
            (self.valid_file, self.valid_triples),
            (self.test_file, self.test_triples),
        ):
            _, forward = self._read_triples(file_name)
            target.extend((t, r + offset, h) for h, r, t in forward)

        all_true = set(self.triples)
        all_true.update(self.valid_triples)
        all_true.update(self.test_triples)
        self.all_true_triples = all_true

    def get_valid(self) -> list[tuple[int, int, int]]:
        """Return the validation triples.

        :returns: :py:attr:`valid_triples`
        :rtype: list[tuple[int, int, int]]
        """
        return self.valid_triples

    def get_test(self) -> list[tuple[int, int, int]]:
        """Return the test triples.

        :returns: :py:attr:`test_triples`
        :rtype: list[tuple[int, int, int]]
        """
        return self.test_triples

    def get_all_true_triples(self) -> set[tuple[int, int, int]]:
        """Return every known triple of the knowledge graph.

        :returns: :py:attr:`all_true_triples`
        :rtype: set[tuple[int, int, int]]
        """
        return self.all_true_triples

    def get_hr2t_rt2h_from_all(self):
        """Build :py:attr:`hr2t_all` and :py:attr:`rt2h_all` from :py:attr:`all_true_triples`.

        Each value is first collected as a set, then frozen into a tensor.
        :py:meth:`sampling` then uses the tensor directly as an index to
        mark all known answers in one assignment.
        """
        for h, r, t in self.all_true_triples:
            self.hr2t_all[(h, r)].add(t)
            self.rt2h_all[(r, t)].add(h)
        # Build the replacement mapping first, then update, so the dict being
        # iterated is never written to mid-iteration.
        self.hr2t_all.update(
            {key: torch.tensor(list(tails)) for key, tails in self.hr2t_all.items()}
        )
        self.rt2h_all.update(
            {key: torch.tensor(list(heads)) for key, heads in self.rt2h_all.items()}
        )

    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph , torch.Tensor]]:
        """Collate function producing a test batch for ``R-GCN`` :cite:`R-GCN`.

        :param data: the ground-truth (h, r, t) triples of this batch
        :type data: list[tuple[int, int, int]]
        :returns: a dict with ``positive_sample`` (the batch as a tensor),
            ``head_label`` / ``tail_label`` (``len(data) x ent_tol`` multi-hot
            masks of every known head / tail), ``graph``, ``rela`` and ``norm``
            (as returned by ``sampler.build_graph``) and ``entity``
            (``ent_tol x 1`` long tensor of entity ids)
        :rtype: dict[str, typing.Union[dgl.DGLGraph , torch.Tensor]]
        """
        batch_size = len(data)
        head_label = torch.zeros(batch_size, self.ent_tol)
        tail_label = torch.zeros(batch_size, self.ent_tol)
        for idx, (head, rel, tail) in enumerate(data):
            head_label[idx, self.rt2h_all[(rel, tail)]] = 1.0
            tail_label[idx, self.hr2t_all[(head, rel)]] = 1.0

        batch_data = {}
        batch_data["positive_sample"] = torch.tensor(data)
        batch_data["head_label"] = head_label
        batch_data["tail_label"] = tail_label

        # NOTE(review): the full training graph is rebuilt for every batch;
        # consider caching it if build_graph is deterministic and does not
        # mutate its inputs.
        graph, rela, norm = self.sampler.build_graph(
            self.ent_tol, np.array(self.triples).transpose(), self.power)
        batch_data["graph"] = graph
        batch_data["rela"] = rela
        batch_data["norm"] = norm
        batch_data["entity"] = torch.arange(0, self.ent_tol, dtype=torch.long).view(-1, 1)
        return batch_data

Docs

Access comprehensive developer documentation for Pybind11-OpenKE

View Docs