写点什么

OneFlow 学习笔记:从 OpExprInterpreter 到 OpKernel

作者:OneFlow
  • 2022 年 4 月 28 日
  • 本文字数:13752 字

    阅读完需:约 45 分钟

OneFlow学习笔记:从OpExprInterpreter到OpKernel

撰文|月踏

更新|赵露阳


前文《OneFlow学习笔记:从Functor到OpExprInterpreter》讲了 OpExprInterpreter 的相关细节,再往下就是 OneFlow 中的虚拟机,它负责在 eager 模式下把指令(即 op,在 vm 中称为指令)调度到具体的 OpKernel 上来执行。

1、Global 简介

先看一个特殊的类 Global,定义在 oneflow/core/common/global.h,这个类很简单,但是对于整个系统来说很重要,主要的几个接口如下:

template<typename T, typename Kind = void>class Global final { public:  // 获取创建过的对象  static T* Get() { ... }  // 创建对象  static void SetAllocated(T* val) { ... }  template<typename... Args>  static T* New(Args&&... args) { ... }  // 释放对象  static void Delete() { ... }  ...};
复制代码

这是一个可以根据指定顺序来创建全局单例对象的类,主要用在系统的初始化中,这样对于一些全局的对象在初始化的时候创建好,后续整个系统的各个模块就都可以使用了。

2、系统初始化过程

再继续看系统的初始化流程,首先在 python/oneflow/__init__.py+217 中可以找到下面这句话:

__oneflow_global_unique_env = env_util.GetEnv()
复制代码

GetEnv()方法在 python/oneflow/framework/env_util.py 中定义,其返回一个 EnvHolder 的 Python 对象,此对象初始化时,通过 self._env_cxt = create_env()创建了 OneFlow 运行时所需要的环境上下文:

class EnvHolder(object):    def __init__(self):        if not HasAllMultiClientEnvVars():            SetDefaultMultiClientEnvVars()        self._env_cxt = create_env()    ...

def create_env(): """create environment

Returns: Env: [description] """ global default_env_proto assert len(default_env_proto.machine) > 0 CompleteEnvProto(default_env_proto) if default_env_proto.ctrl_bootstrap_conf.world_size > 1: check_non_localhost_proxy_and_print_warning() return c_api_util.GetEnvContext(default_env_proto)
复制代码

create_env()中,首先会通过 CompleteEnvProto 创建默认的 env_proto 对象,然后根据此 env proto 对象创建 oneflow 所需要的环境上下文 env_ctx。

这里面和初始化相关的主线是 GetEnvContext,其定位位于 python/oneflow/framework/c_api_util.py+45:

def GetEnvContext(env_proto):    assert type(env_proto) is env_pb2.EnvProto    env_proto_str = text_format.MessageToString(env_proto)    env_ctx = oneflow._oneflow_internal.EnvContext(env_proto_str)    return env_ctx
复制代码

这个 EnvContext 是 oneflow 内部导出的 c api,其定义位于:oneflow/api/python/env/env.cpp:L46。

其作用即初始化一个单例——env 作用域对象 EnvGlobalObjectsScope,并在其构造之初,通过 oneflow/core/job/env_global_objects_scope.cpp:L153 的 EnvGlobalObjectsScope::Init()方法初始化一些系统需要的其他全局单例对象/配置:

Maybe<void> EnvGlobalObjectsScope::Init(const EnvProto& env_proto) {  ...  Global<EnvDesc>::New(env_proto);  Global<ProcessCtx>::New();  ...#ifdef WITH_CUDA  Global<EagerNcclCommMgr>::New();  Global<CudnnConvAlgoCache>::New();  Global<embedding::EmbeddingManager>::New();#endif Global<vm::VirtualMachineScope>::New(Global<ResourceDesc, ForSession>::Get()->resource());  Global<EagerJobBuildAndInferCtxMgr>::New();  ...   return Maybe<void>::Ok();}
复制代码

上面删去了很多代码,只展示了部分对象的创建,如:Globalvm::VirtualMachineScope::New。


它会创建一个 VirtualMachineScope 的单例对象,这个类的构造函数因此会被执行一次,如下所示:

VirtualMachineScope::VirtualMachineScope(const Resource& resource) {  Global<VirtualMachine>::New(resource, GlobalProcessCtx::Rank());}
复制代码

在这个构造函数里,又通过 Global 创建了一个 VirtualMachine 的单例对象,这是个很重要的单例对象,后面讲虚拟机时会用到它,所以先在这一节引出。

3、StreamType 和 InstructionType 的注册

还需要再看一部分和后面虚拟机非常相关的内容作为准备,它们是 StreamType 和 InstructionType 的注册,先看下面这段代码,位于 oneflow/core/eager/cpu_opkernel_instruction_type.cpp+34:

COMMAND(vm::RegisterInstructionType<CpuLocalCallOpKernelInstructionType>("cpu.LocalCallOpKernel"));
复制代码

COMMAND 是一个宏,位于 oneflow/core/common/util.h+115,它的实现很巧妙,利用了匿名空间来保证在源文件定义的变量只在源文件可见,用 CommandT 和__LINE__在源文件中定义了一个唯一名字的 struct,把注册语句放在它的构造函数中,然后再定义一个该 struct 的对象,其构造函数被自动执行的时候,注册语句也被执行:

#define COMMAND(...)                                                \  namespace {                                                       \  struct OF_PP_CAT(CommandT, __LINE__) {                            \    OF_PP_CAT(CommandT, __LINE__)() { __VA_ARGS__; }                \  };                                                                \  OF_PP_CAT(CommandT, __LINE__) OF_PP_CAT(g_command_var, __LINE__); \  }
复制代码

再看实际的注册语句,它的模板参数是 CpuLocalCallOpKernelInstructionType,定义在 oneflow/core/eager/cpu_opkernel_instruction_type.cpp+27,如下所示:

class CpuLocalCallOpKernelInstructionType final : public LocalCallOpKernelInstructionType { public:  CpuLocalCallOpKernelInstructionType() = default;  ~CpuLocalCallOpKernelInstructionType() override = default;

using stream_type = vm::CpuStreamType;};
复制代码

这段代码中的 stream_type 在下面会很有用,这段代码其实是把 CpuLocalCallOpKernelInstructionType 类和 vm::CpuStreamType 类建立了关联,再继续看 COMMAND 宏中的注册语句,单独摘出来如下所示:

vm::RegisterInstructionType<CpuLocalCallOpKernelInstructionType>("cpu.LocalCallOpKernel")
复制代码

RegisterInstructionType 是一个模板函数,定义位于 oneflow/core/vm/instruction_type.h+80:

template<typename T>void RegisterInstructionType(const std::string& instr_type_name) {  RegisterInstrTypeId<T>(instr_type_name, StaticGlobalStreamType<typename T::stream_type>());}
复制代码

以这里 COMMAND 的示例中对 CpuLocalCallOpKernelInstructionType 的注册为例,按行来看,注册函数 RegisterInstructionType 主要内容在:oneflow/core/vm/instruction_type.cpp+54:

void RegisterInstrTypeId(const std::string& instruction_name, const StreamType* stream_type,                         const InstructionType* instruction_type) {  InstrTypeId instr_type_id;  instr_type_id.__Init__(stream_type, instruction_type);  CHECK(InstrTypeId4InstructionName()->emplace(instruction_name, instr_type_id).second);}
复制代码

实际做了下面几件事(CpuLocalCallOpKernelInstructionType 的名字较长,为了方便表示,下面简称它为 T):

  • 初始化一个 InstrTypeId 对象,并调用其__Init__方法为其成员变量 stream_type_和 instruction_type_赋值,这里 stream_type 就是 T::stream_type,即 vm::CpuStreamType;instruction_type 即指向 T 的指令类型的指针对象。

  • 通过 InstrTypeId4InstructionName()方法拿到一个静态 HashMap<std::string, InstrTypeId> map 对象的指针。

  • 将 instruction_name("cpu.LocalCallOpKernel")作为 key,InstrTypeId 对象 instr_type_id 作为 value 插入这个 map 中。


4、虚拟机调度过程 1

前文《OneFlow学习笔记:从Functor到OpExprInterpreter》讲到了调用 PhysicalRun 之前的 mirror mode 和 eager mode 的大概流程,已经准备好了输入输出的 EagerBlobObject 以及一些 context 信息和相关的 device 信息,在调用 PhysicalRun 这个函数之后,就进入了虚拟机的部分。

4.1 放指令线程

PhysicalRun 接受一个 call-back function 作为参数,这个 call-back 函数中会调用 builder->LocalCallOpKernel 这个函数,并且以前面准备好的输入、输出、ctx、device 作为参数来执行,先来看 PhysicalRun 函数,它定义在 oneflow/core/framework/instructions_builder.cpp+595:

Maybe<void> PhysicalRun(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {  vm::InstructionMsgList instruction_list;  InstructionsBuilder instructions_builder(std::make_shared<vm::PhysicalIdGenerator>(),                                           &instruction_list);  JUST(Build(&instructions_builder));  JUST(vm::Run(instructions_builder.mut_instruction_list()));  return Maybe<void>::Ok();}
复制代码

这里的 Build 就是刚从传进来的 call-back 函数,整理出来再来加深一下印象:

[&](InstructionsBuilder* builder) -> Maybe<void> {    return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, op_device);}
复制代码

在 PhysicalRun 中,以 InstructionsBuilder 对象为参数来调用这个 call-back function,所以会执行 InstructionsBuilder 中的 LocalCallOpKernel 函数,这个函数位于 oneflow/core/framework/instructions_builder.cpp+347:

Maybe<void> InstructionsBuilder::LocalCallOpKernel(...) {  ...  auto phy_instr_operand = JUST(vm::LocalCallOpKernelPhyInstrOperand::New(      opkernel, input_eager_blob_objects, output_eager_blob_objects, consistent_tensor_infer_result,      ctx, *one::CurrentDevVmDepObjectConsumeMode()));  auto instruction = intrusive::make_shared<vm::InstructionMsg>(      Global<VirtualMachine>::Get()->mut_vm(), JUST(op_device->local_call_instruction_name()),      parallel_desc_sym, phy_instr_operand);  instruction_list_->EmplaceBack(std::move(instruction));  ...  return Maybe<void>::Ok();}
复制代码

这个函数逻辑大概是把输入 op 相关的信息打包成一个 vm::InstructionMsg 对象,然后放到 instruction_list_这个 list 中。

到这里前面的 PhysicalRun 中的 Build 部分就分析完了,继续看 Build 之后的逻辑 vm::Run,它主要是调了 oneflow/core/vm/vm_util.cpp+34 中的 Run 方法:

Maybe<void> Run(vm::InstructionMsgList* instr_msg_list) {  auto* virtual_machine = JUST(GlobalMaybe<VirtualMachine>());  JUST(virtual_machine->Receive(instr_msg_list));  return Maybe<void>::Ok();}
复制代码

这里通过 GlobalMaybe 来得到了在前面第一节 OneFlow 初始化中讲到的被创建好的 VirtualMachine 对象,这里调用了 VirtualMachine 中的 Receive 函数,位于 oneflow/core/vm/virtual_machine.cpp+204:

Maybe<bool> VirtualMachineEngine::Receive(    intrusive::shared_ptr<InstructionMsg>&& compute_instr_msg) {  InstructionMsgList instr_msg_list;  instr_msg_list.EmplaceBack(std::move(compute_instr_msg));  return Receive(&instr_msg_list);}
复制代码

这里的 vm_变量类型是 intrusive::shared_ptrvm::VirtualMachineEngine,在我们的示例中,会走到 else 分支,也就调用了 VirtualMachineEngine 的 Receive 函数,它位于 oneflow/core/vm/virtual_machine_engine.cpp+422,VirtualMachineEngine 是一个很大很复杂的类,这里我们不关注它的其它功能,只关注当前的流程,下面是 Receive 函数的代码:

Maybe<bool> VirtualMachineEngine::Receive(InstructionMsgList* compute_instr_msg_list) {  OF_PROFILER_RANGE_PUSH("vm:Receive");  INTRUSIVE_UNSAFE_FOR_EACH_PTR(compute_instr_msg, compute_instr_msg_list) {    OF_PROFILER_RANGE_PUSH(compute_instr_msg->DebugName());    OF_PROFILER_RANGE_POP();  }  bool old_list_empty = mut_pending_msg_list()->MoveFrom(compute_instr_msg_list);  OF_PROFILER_RANGE_POP();  return old_list_empty;}

Maybe<bool> VirtualMachineEngine::Receive( intrusive::shared_ptr<InstructionMsg>&& compute_instr_msg) { InstructionMsgList instr_msg_list; instr_msg_list.EmplaceBack(std::move(compute_instr_msg)); return Receive(&instr_msg_list);}
复制代码

从这里看到并没有指令被执行,唯一的一条线索是传进来的 compute_instr_msg_list 最终被放入了 mut_pending_msg_list()中,当前的线程只负责往队列里放指令,另外有线程会从队列里往外取指令来执行,所以继续搜下 mut_pending_msg_list()会在哪里被用到,可以搜到在 oneflow/core/vm/virtual_machine_engine.cpp+514 的 Schedule 函数中被调用,Schedule 又在 oneflow/core/vm/virtual_machine.cpp+291 中的 ScheduleLoop 函数中被调用,这就引入了使用指令的线程。

4.2 用指令线程

直接看 ScheduleLoop 线程函数被启动的地方,它在 VirtualMachine 的构造函数中作为一个线程函数被创建和启动,VirtualMachine 的构造函数位于 oneflow/core/vm/virtual_machine.cpp+114,如下所示:

VirtualMachine::VirtualMachine(const Resource& resource, int64_t this_machine_id)    : vm_threads_closed_(false) {  ...  std::function<void()> SchedulerInitializer;  GetSchedulerThreadInitializer(&SchedulerInitializer);  schedule_thread_ = std::thread(&VirtualMachine::ScheduleLoop, this, SchedulerInitializer);}
复制代码

从前面第一节讲的的 OneFlow 初始化流程中可知,在 OneFlow 初始化的时候创建一个 VirtualMachine 的全局对象,自然其构造函数会被调用,所以这个 VirtualMachine::ScheduleLoop 线程函数在那时就被启动了,继续看 ScheduleLoop 的内容,位于 oneflow/core/vm/virtual_machine.cpp+291:

void VirtualMachine::ScheduleLoop(const std::function<void()>& Initializer) {  ...  while (pending_notifier_.WaitAndClearNotifiedCnt() == kNotifierStatusSuccess) {    ...    do {      ...      do {        ...        do { vm->Schedule(schedule_ctx); } while (!vm->ThreadUnsafeEmpty());        vm->MoveToGarbageMsgListAndNotifyGC(schedule_ctx);      } while (++i < kNumSchedulingPerTimoutTest);    } while (MicrosecondsFrom(start) < kWorkingMicroseconds);  }  ...}
复制代码

这里面最重要的是 Schedule 函数的调用,位于 oneflow/core/vm/virtual_machine_engine.cpp+514,简化代码如下

void VirtualMachineEngine::Schedule() {  if (...) { ReleaseFinishedInstructions(); }  if (...) { TryRunBarrierInstruction(); }  if (...) { HandleLocalPending(); }  if (...) { DispatchAndPrescheduleInstructions(); }}
复制代码

这个函数里比较重要的两个函数是 HandleLocalPending 和 DispatchAndPrescheduleInstructions,先看 HandleLocalPending,位于 oneflow/core/vm/virtual_machine_engine.cpp+62,它的精简代码如下:

void VirtualMachineEngine::HandlePending() {  ...  InstructionMsgList pending_instr_msgs;  INTRUSIVE_FOR_EACH_PTR(instr_msg, &pending_instr_msgs) {    MakeInstructions(instr_msg, /*out*/ &new_instruction_list);  }  ...  INTRUSIVE_FOR_EACH_PTR(instruction, &new_instruction_list) {    ConsumeMirroredObjects(instruction);    if (likely(Dispatchable(instruction))) {      mut_ready_instruction_list()->PushBack(instruction);      new_instruction_list.Erase(instruction);    }  }}
复制代码

可见它的工作主要是通过 MakeInstructions 制作指令,然后把指令放入 list,再看 DispatchAndPrescheduleInstructions,它位于 oneflow/core/vm/virtual_machine_engine.cpp+320:

void VirtualMachineEngine::DispatchAndPrescheduleInstructions() {  ReadyInstructionList tmp_ready_instruction_list;  mut_ready_instruction_list()->MoveTo(&tmp_ready_instruction_list);  INTRUSIVE_FOR_EACH(instruction, &tmp_ready_instruction_list) {    ...    DispatchInstruction(instruction.Mutable());    ...  }  ...}
复制代码

这个函数的主要工作是调用了 DispatchInstruction,继续来看一下这个函数,位于 oneflow/core/vm/virtual_machine_engine.cpp+344:

void VirtualMachineEngine::DispatchInstruction(Instruction* instruction,                                               const ScheduleCtx& schedule_ctx) {  auto* stream = instruction->mut_stream();  stream->mut_running_instruction_list()->PushBack(instruction);  if (stream->active_stream_hook().empty()) { mut_active_stream_list()->PushBack(stream); }  const auto& stream_type = stream->stream_type();  if (OnSchedulerThread(stream_type)) {    stream_type.Run(instruction);  } else {    stream->mut_thread_ctx()->mut_pending_instruction_list()->PushBack(instruction);    schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx());  }}
复制代码

从这个函数中可以看出,指令被 stream_type.Run 来执行了,这里打断一下,用下面一节内容来追一下这里的 stream_type 从哪来的。

5、指令中的 stream

从上面第四节的最后一段代码中,可以看到 stream_type 来自于 stream,stream 来自于 Instruction,本节来追一下 Instruction 中的 stream 是怎么来的。

以 mirror mode 为例,代码会首先进入 4.1 节讲过的 LocalCallOpKernel 函数执行,位于 oneflow/core/framework/instructions_builder.cpp+347:

Maybe<void> InstructionsBuilder::LocalCallOpKernel(..., Symbol<Device> op_device) {  ...  const auto& instruction_name = JUST(StreamRoleSwitch<GetCallInstructionName>(      stream->stream_role(), stream->device()->enum_type()));  auto instruction = intrusive::make_shared<vm::InstructionMsg>(      Global<VirtualMachine>::Get()->mut_vm(), instruction_name, parallel_desc_sym,      phy_instr_operand);  instruction_list_->EmplaceBack(std::move(instruction));  ...  return Maybe<void>::Ok();}
复制代码

这里主要是在创建指令 instruction 对象,创建完成后放入指令列表末尾。

这里先看一下 instruction_name 是怎么产生的,在 GetCallInstructionName 的结构体中维护着 stream_role、stream type 以及对应的指令名称 instruction_name 之间的映射关系,在 StreamRoleSwitch 模板中会转发至其 Case 方法,并最终返回 instruction_name 的字符串。

所以在我们的示例中会返回"cpu.LocalCallOpKernel",在第三节中的注册示例中,可以看到以这个字符串为 key,注册了 CpuLocalCallOpKernelInstructionType 这个类,它关联了 vm::CpuStreamType 类型,这些信息在后面都会用到。

再看 InstructionMsg,它的定义位于 oneflow/core/vm/instruction.h+39:

class InstructionMsg final : public intrusive::Base {  ...  InstrTypeId instr_type_id_;  std::string instr_type_name_;  ...  Stream* phy_instr_stream_;};
复制代码

InstructionMsg 持有的 InstrTypeId、Stream 指针这两个成员和我们要追的 stream 的线索最相关,我们只需要关注这两个成员就好,在前面调用 intrusive::make_sharedvm::InstructionMsg(...)的时候,根据 intrusive::make_shared 的实现,会调用到 InstructionMsg 的下面这个__Init__函数,位于 oneflow/core/vm/instruction.cpp+42:

void InstructionMsg::__Init__(VirtualMachineEngine* vm, const std::string& instr_type_name,                              const std::shared_ptr<const ParallelDesc>& phy_instr_parallel_desc,                              const std::shared_ptr<PhyInstrOperand>& phy_instr_operand) {  __Init__();  if (likely(phy_instr_parallel_desc)) {    int device_id = phy_instr_parallel_desc->parallel_id2device_id().at(0);    vm->GetCachedInstrTypeIdAndPhyInstrStream(instr_type_name, device_id, mut_instr_type_id(),                                              &phy_instr_stream_);  }  ...}
复制代码

instr_type_id_和 phy_instr_stream_的赋值就是在上面代码中的 GetCachedInstrTypeIdAndPhyInstrStream 函数调用中完成的,定义位于 oneflow/core/vm/virtual_machine_engine.cpp+383:

void VirtualMachineEngine::GetCachedInstrTypeIdAndPhyInstrStream(const std::string& instr_type_name,                                                                 int device_id,                                                                 InstrTypeId* instr_type_id,                                                                 Stream** stream) {  auto* cache = &instr_type_name2rt_instr_type_id_;  auto iter = cache->find(instr_type_name);  if (unlikely(iter == cache->end())) {    const auto& instr_type_id_val = LookupInstrTypeId(instr_type_name);    const auto* stream_type = &instr_type_id_val.stream_type();    auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type);    iter = cache->emplace(instr_type_name, RtInstrTypeId(instr_type_id_val, stream_rt_desc)).first;  }  instr_type_id->CopyFrom(iter->second.instr_type_id());  *stream = iter->second.GetStream(device_id);}
复制代码

这一段代码其实涉及的内容非常多,这里只能简单说一下,函数传进来的 instr_type_name 是"cpu.LocalCallOpKernel",先在 VirtualMachineEngine 的下面这个 map 成员查询这个 key:

std::map<std::string, RtInstrTypeId> instr_type_name2rt_instr_type_id_;
复制代码

这个 map 的 value type 是 RtInstrTypeId,从它可以得到 InstrTypeId 和相应的 Stream 指针,它定义位于 oneflow/core/vm/runtime_instr_type_id.h+25:

class RtInstrTypeId final { public:  RtInstrTypeId(const RtInstrTypeId&) = default;  RtInstrTypeId(RtInstrTypeId&&) = default;  ~RtInstrTypeId() = default;

RtInstrTypeId(const InstrTypeId& instr_type_id, StreamRtDesc* stream_rt_desc) : instr_type_id_(instr_type_id), stream_rt_desc_(stream_rt_desc) { if (stream_rt_desc->stream_type().IsControlStreamType()) { get_stream_ = &StreamRtDesc::GetSoleStream; } else { get_stream_ = &StreamRtDesc::GetDeviceStream; } }

const InstrTypeId& instr_type_id() const { return instr_type_id_; } Stream* GetStream(int device_id) const { return (stream_rt_desc_->*get_stream_)(device_id); }

private: const InstrTypeId instr_type_id_; StreamRtDesc* stream_rt_desc_; Stream* (StreamRtDesc::*get_stream_)(int device_id) const;};
复制代码

如果没有从这个 map 中找到"cpu.LocalCallOpKernel"这个 key,则会做下面操作:

if (unlikely(iter == cache->end())) {  const auto& instr_type_id_val = LookupInstrTypeId(instr_type_name);  const auto* stream_type = &instr_type_id_val.stream_type();  auto* stream_rt_desc = this->mut_stream_type2stream_rt_desc()->FindPtr(stream_type);  iter = cache->emplace(instr_type_name, RtInstrTypeId(instr_type_id_val, stream_rt_desc)).first;}
复制代码

先通过 LookupInstrTypeId 查询第三节注册的数据结构 C,从而找到"cpu.LocalCallOpKernel"相应的 InstrTypeId,它里面包含相关的 StreamTypeId 信息,再使用这个 StreamTypeId,通过调用 mut_stream_type_id2stream_rt_desc()->FindPtr 来找到对应的 StreamRtDesc 对象指针,然后根据 instr_type_id_val 和 stream_rt_desc 构造一个 RtInstrTypeId 对象作为 value,维护到前面的 map 中,最后再从这个 map 得到 InstrTypeId 和相应的 Stream 指针返回。

顺便说一下 mut_stream_type_id2stream_rt_desc()对应的数据结构,它在 VirtualMachineEngine 的__Init__函数中(构造的时候被调用)被初始化,位于 oneflow/core/vm/virtual_machine_engine.cpp+358:

void VirtualMachineEngine::__Init__(const VmDesc& vm_desc) {  ...  INTRUSIVE_UNSAFE_FOR_EACH_PTR(stream_desc, &vm_desc.stream_type_id2desc()) {    if (stream_desc->num_threads() == 0) { continue; }    auto stream_rt_desc = intrusive::make_shared<StreamRtDesc>(stream_desc);    mut_stream_type_id2stream_rt_desc()->Insert(stream_rt_desc.Mutable());    ...  }}
复制代码

这样就知道了构造好的 InstructionMsg 对象是怎么包含的 Stream 信息,继续看 InstructionMsg 是怎么转换为 Instruction 对象的,在前面 4.2 节中讲的 HandleLocalPending 函数,位于 oneflow/core/vm/virtual_machine_engine.cpp+62:

void VirtualMachineEngine::HandlePending() {  ...  InstructionMsgList pending_instr_msgs;  INTRUSIVE_FOR_EACH_PTR(instr_msg, &pending_instr_msgs) {    MakeInstructions(instr_msg, /*out*/ &new_instruction_list);  }  ...  INTRUSIVE_FOR_EACH_PTR(instruction, &new_instruction_list) {    ConsumeMirroredObjects(instruction);    if (likely(Dispatchable(instruction))) {      mut_ready_instruction_list()->PushBack(instruction);      new_instruction_list.Erase(instruction);    }  }}
复制代码

其中的 MakeInstructions 会做这个转换,它的定义位于 oneflow/core/vm/virtual_machine_engine.cpp+226,原来的 Stream 信息也会被维护到这个新的数据结构中:

void VirtualMachineEngine::MakeInstructions(InstructionMsg* instr_msg,                                            /*out*/ InstructionList* new_instruction_list) {  const auto& instruction_type = instr_msg->instr_type_id().instruction_type();  bool is_barrier_instruction = instruction_type.IsFrontSequential();  Stream* stream = CHECK_NOTNULL(instr_msg->phy_instr_stream());  const auto& pd = instr_msg->phy_instr_parallel_desc();  intrusive::shared_ptr<Instruction> instr = stream->NewInstruction(instr_msg, pd);  LivelyInstructionListPushBack(instr.Mutable());  if (unlikely(is_barrier_instruction)) {    mut_barrier_instruction_list()->PushBack(instr.Mutable());  } else {    new_instruction_list->PushBack(instr.Mutable());  }}
复制代码

以上就是第四节末尾代码调用 stream_type.Run()的时候,stream_type 的由来,由前面的分析可知,它的实际类型就是和 CpuLocalCallOpKernelInstructionType 建立好关联的 vm::CpuStreamType!下面继续看虚拟机的调度过程。

6、虚拟机调度过程 2

再继续看第四节的最后一段代码,为方便阅读,重新贴一下主要内容,位于 oneflow/core/vm/virtual_machine_engine.cpp+344:

void VirtualMachineEngine::DispatchInstruction(Instruction* instruction) {  ...  if (OnSchedulerThread(stream_type)) {    stream_type.Run(instruction);  } else {    stream->mut_thread_ctx()->mut_pending_instruction_list()->PushBack(instruction);    schedule_ctx.OnWorkerLoadPending(stream->mut_thread_ctx());  }  ...}
复制代码

从这个函数中可以看出,指令被 stream_type.Run 来执行了,从前面第五节的分析可知,stream_type 是 vm::CpuStreamType 类型,继承自 StreamType 类型,StreamType 定义于 oneflow/core/vm/stream_type.h,下面是它的主要接口:

class StreamType { public:  virtual ~StreamType() = default;  void Run(Instruction* instruction) const { Compute(instruction); }

virtual const char* stream_tag() const = 0; virtual void InitDeviceCtx(std::unique_ptr<DeviceCtx>* device_ctx, Stream* stream) const = 0; virtual void InitInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const = 0; virtual void DeleteInstructionStatus(const Stream& stream, InstructionStatusBuffer* status_buffer) const = 0; virtual bool QueryInstructionStatusDone(const Stream& stream, const InstructionStatusBuffer& status_buffer) const = 0; virtual void Compute(Instruction* instruction) const = 0; virtual intrusive::shared_ptr<StreamDesc> MakeStreamDesc(const Resource& resource, int64_t this_machine_id) const = 0; virtual bool OnSchedulerThread() const = 0; virtual bool SupportingTransportInstructions() const = 0; virtual bool IsControlStreamType() const { return false; }

protected: StreamType() = default;};
复制代码

这里面含有前面代码中用到的 Run 接口(stream_type.Run),它的实现位于 Compute 函数中。从 StreamType 的定义可以知道,这是一个虚接口,StreamType 有下面这些子类实现:

图 1

我们这里使用的是 CpuStreamType,定义位于 oneflow/core/vm/cpu_stream_type.h,它的 Compute 函数位于 oneflow/core/vm/cpu_stream_type.cpp+50,如下所示:

void CpuStreamType::Compute(Instruction* instruction) const {  ...  {    const auto& instr_type_id = instruction->mut_instr_msg()->instr_type_id();    instr_type_id.instruction_type().Compute(instruction);  }  auto* status_buffer = instruction->mut_status_buffer();  NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data())->set_done();  ...}
复制代码

可以看到这里又调用了 instr_type_id.instruction_type().Compute()这个函数,这个 Compute 属于 instruction_type()对应的类中,可以查到 instruction_type()会返回一个 InstructionType 类型的 const 引用对象,所以关注 InstructionType 类即可,它的定义位于 oneflow/core/vm/instruction_type.h,里面有 Compute 虚接口:

class InstructionType {  ...  virtual void Compute(Instruction* instruction) const = 0;  virtual void ComputeInFuseMode(InstructionMsg* instr_msg) const { LOG(FATAL) << "UNIMPLEMENTED"; }  ...};
复制代码

这也是个继承体系,InstructionType 有非常多的子类,下面是我找到的一部分示例,没有列完:

我们调用的 Compute 位于上图中的 LocalCallOpKernelInstructionType,位于 oneflow/core/eager/opkernel_instruction_type.cpp+150,它的 Compute 函数定义如下:

void LocalCallOpKernelInstructionType::Compute(vm::Instruction* instruction) const {  CHECK_JUST(LocalCallOpKernelUtil::Compute(instruction));}
复制代码

可见又继续调用了 LocalCallOpKernelUtil::Compute,继续追这个函数,它的定义位于 oneflow/core/eager/opkernel_instruction_type.cpp+44:

struct LocalCallOpKernelUtil final {  static inline Maybe<void> Compute(vm::Instruction* instruction) {    ...    OpKernelCompute(operand, device_ctx, state, cache);    ...    return Maybe<void>::Ok();  }  ...};
复制代码

这里又继续调用了 OpKernelCompute,在同一个类中:

struct LocalCallOpKernelUtil final {  ...  static inline void OpKernelCompute(LocalCallOpKernelPhyInstrOperand* operand,                                     DeviceCtx* device_ctx, user_op::OpKernelState* state,                                     const user_op::OpKernelCache* cache) {    ...    operand->user_opkernel()->Compute(compute_ctx, state, cache);    ...  }};
复制代码

其中 user_opkernel()会返回一个 user_op::OpKernel 的指针,而这个 OpKernel 就是我们定义算子的时候必须要继承的一个基类,以我们的 relu 示例来说,relu 的计算部分定义在 oneflow/user/kernels/relu_kernel.cpp,精简代码如下:

class ReluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { private:  void Compute(user_op::KernelComputeContext* ctx) const override {    // do computing!  }};
复制代码

至此,终于从上到下打通了一条执行路线!


Reference


本文主要梳理了 OneFlow 虚拟机的的作用和相关实现,主要参考的是 OneFlow 的官方代码和之前的一些相关文章,但限于篇幅和本人目前的认知,里面有很多地方还没有弄懂或者没有总结,比如指令边的部分,SkipList、SkipListHead、ListHookArray、ListHook、SkipListHook 等基础数据结构的作用及实现细节等,需要继续学习的地方还有很多,继续加油~


下面是相关链接:


(本文参考代码:

https://github.com/Oneflow-Inc/oneflow/commit/888ad73fe28e2a4509ce7e563f196011e88b817d


特别感谢同事路强、俊丞、后江在我学习和理解这部分内容的过程中提供的帮助。


其他人都在看


欢迎下载体验 OneFlow v0.7.0 最新版本:https://github.com/Oneflow-Inc/oneflow/

发布于: 刚刚阅读数: 3
用户头像

OneFlow

关注

不至于成为世界上最快的深度学习框架。 2022.03.23 加入

★ OneFlow深度学习框架:github.com/Oneflow-Inc/oneflow ★ OF云平台:oneflow.cloud

评论

发布
暂无评论
OneFlow学习笔记:从OpExprInterpreter到OpKernel_数据结构_OneFlow_InfoQ写作社区