UniSampler¶
- class pybind11_ke.data.UniSampler(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
平移模型和语义匹配模型的训练集普通的数据采样器(均值分布)。
- __init__(in_path: str = './', ent_file: str = 'entity2id.txt', rel_file: str = 'relation2id.txt', train_file: str = 'train2id.txt', batch_size: int | None = None, neg_ent: int = 1)[源代码]¶
创建 UniSampler 对象。
- __weakref__¶
list of weak references to the object (if defined)
- corrupt_head(t: int, r: int, num_max: int = 1) numpy.ndarray¶
替换头实体构建负三元组。
- 参数:
- 返回:
负三元组的头实体列表
- 返回类型:
- corrupt_tail(h: int, r: int, num_max: int = 1) numpy.ndarray¶
替换尾实体构建负三元组。
- 参数:
- 返回:
负三元组的尾实体列表
- 返回类型:
- get_hr2t_rt2h_from_train()¶
获得
hr2t_train和rt2h_train。
- get_train_triples_id()¶
读取
train_file文件。
- head_batch(t: int, r: int, neg_size: int | None = None) numpy.ndarray[源代码]¶
替换头实体构建负三元组。
- 参数:
- 返回:
负三元组中的头实体列表
- 返回类型:
- hr2t_train: collections.defaultdict[set]¶
训练集中所有 h-r 对对应的 t 集合
- rt2h_train: collections.defaultdict[set]¶
训练集中所有 r-t 对对应的 h 集合
- sampling(pos_triples: list[tuple[int, int, int]]) dict[str, Union[str, torch.Tensor]][源代码]¶
平移模型和语义匹配模型的训练集普通的数据采样函数(均匀分布)。