pybind11_ke.config.TrainerAccelerator 源代码
# coding:utf-8
#
# pybind11_ke/config/TrainerAccelerator.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 12, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 27, 2024
#
# 该脚本定义了并行训练循环函数.
"""
利用 accelerate 实现并行训练。
"""
from typing import Any, List, Optional

from accelerate import Accelerator
def accelerator_prepare(
    *args: Any,
    accelerator: Optional["Accelerator"] = None,
) -> List[Any]:
    """
    Prepare objects for distributed training with an ``accelerate`` Accelerator.

    Distributed parallelism relies on
    `accelerate <https://github.com/huggingface/accelerate>`_, so the objects
    used for training must first be wrapped by ``Accelerator.prepare``.

    Example::

        dataloader, model, accelerator = accelerator_prepare(
            dataloader,
            model
        )

    :param args: objects to prepare, e.g.
        :py:class:`pybind11_ke.data.KGEDataLoader` and
        :py:class:`pybind11_ke.module.strategy.Strategy`.
    :type args: typing.Any
    :param accelerator: an existing Accelerator to use (e.g. one configured
        for mixed precision). When ``None`` (the default), a new
        ``Accelerator()`` is created.
    :type accelerator: accelerate.Accelerator | None
    :returns: the prepared objects, in the same order as ``args``, followed
        by the Accelerator that prepared them. The result is always a list,
        even when a single object (or none) is passed.
    :rtype: typing.List[typing.Any]
    """
    if accelerator is None:
        accelerator = Accelerator()

    if not args:
        # Nothing to prepare; Accelerator.prepare() is not meant to be called
        # without objects, so just hand back the accelerator.
        return [accelerator]

    prepared = accelerator.prepare(*args)
    # Accelerator.prepare() returns a tuple for two or more objects but the
    # bare object for exactly one. Calling list() on that bare object would
    # iterate it (e.g. drain a dataloader) or fail for a non-iterable model.
    if len(args) == 1:
        prepared = (prepared,)
    return [*prepared, accelerator]