TensorFlow中的并行执行引擎——StreamExecutor框架

 

背景

[作者:

StreamExecutor为TensorFlow的执行层面提供了较为统一的抽象,而在底层各种Device的执行管理细节却完全不同。我们可以看到stream_executor下面有cuda和host两个子目录,他们分别是GPU执行引擎和CPU执行引擎所使用的子模块。下面我们先从统一的抽象层面来梳理该框架的结构。

StreamExecutor对外提供的句柄——Stream对象

为了隐藏StreamExecutor框架管理的复杂性,它对外暴露的handler必须足够简单。事实也确实如此,StreamExecutor通过暴露Stream对象作为操作底层的handler。一般而言,在TensorFlow的框架中都是使用Stream对象来调用底层计算库,进行设备间数据拷贝操作等过程。比如调用Stream对象的ThenMemcpy即可完成异步的数据传输拷贝过程,调用ThenConvolveXXX等函数即可完成DNN库中的卷积调用。事实上,TensorFlow中很多Op的C++实现中,其Compute函数内就是通过使用Stream对象来完成某些实际计算或数据拷贝的过程,下图展示了Stream对象、StreamExecutor框架以及其他模块的关系。

Stream对象是通过持有StreamInterface的具体实现对象来获得实际平台的Stream,进而通过Stream这个统一的handler完成与底层的交互,下面试这一子模块的类图结构。

 

StreamExecutor框架内的层次结构

熟悉GPU编程的同学都知道,CUDA程序的编写是相对复杂的,不但要针对某种任务设计特定的并行编程思路,还要管理Event,Stream等较为底层的对象。为了能够减轻StreamExecutor用户的使用负担,也为了能够给上层调用者即TensorFlow引擎提供更加统一的接口,一些抽象分层的工作是非常有必要的。总体上StreamExecutor框架由三个层次组成,从上到下依次为Platform层(平台描述)、StreamExecutor Core层(执行引擎)和LibrarySupport层(基础库)。如果需要为TensorFlow添加新的计算设备种类,不但要向TensorFlow中注册Device的定义,还需要在StreamExecutor框架中提供负责管理该Device计算的代码。

Platform层

在StreamExecutor中Platform指的是计算所使用设备平台的抽象,每种Device对应一种Platform。比如GPU对应的是CudaPlatform,而CPU对应的是HostPlatform等。一旦获得了某种Device的Platform,就可以获取和该Platform对应的StreamExecutor Core以及相应的LibrarySupport。在TensorFlow的代码实现中,所有Platform类都是通过宏定义和MultiPlatformManager管理类的静态方法主动注册到系统中的,下面是这一层次的类图表示。

CudaPlatform和HostPlatform继承自公共父类Platform,如果有新的Platform出现,依然可以沿用这样的设计直接继承并给出实现。所有的Platform都通过MultiPlaftormManager调用RegsiterPlatform函数主动注册到系统中并做初始化,下面代码段是CudaPlaftorm的注册过程,注册使用了Initializer模块及相应的宏定义,这些代码比较简单,这里就不再详细展开了。

复制代码
 1 static void InitializeCudaPlatform() {  2   // Disabling leak checking, MultiPlatformManager does not destroy its  3   // registered platforms. 4  5   std::unique_ptr<cuda::CudaPlatform> platform(new cuda::CudaPlatform);  6   SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));  7 }  8  9 }  // namespace stream_executor10 11 REGISTER_MODULE_INITIALIZER(cuda_platform, 12                             stream_executor::InitializeCudaPlatform()); 13 14 // Note that module initialization sequencing is not supported in the 15 // open-source project, so this will be a no-op there.16 REGISTER_MODULE_INITIALIZER_SEQUENCE(cuda_platform, multi_platform_manager); 17 REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, 18                                      cuda_platform);
复制代码

MultiPlatformManager提供了两种获取具体Platform的方式,一种是通过name,另一种是通过Id,如下代码段所示。

复制代码
 1   // Retrieves the platform registered with the given platform name (e.g.  2   // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the  3   // Platform's Id() method).  4   // 5   // If the platform has not already been initialized, it will be initialized  6   // with a default set of parameters.  7   // 8   //
                        
关键字:
50000+
5万行代码练就真实本领
17年
创办于2008年老牌培训机构
1000+
合作企业
98%
就业率

联系我们

电话咨询

0532-85025005

扫码添加微信