在h100发布之际,英伟达还带来一个“重磅产品”——transformer engine。在transformer大火之际推出这么一个产品,无疑是炼丹师福音。
当时我还在猜测它会以怎么样的一种形式呈现给用户,直到最近公开了仓库 nvidia/transformerengine
这其实就是pytorch的一个拓展,为了利用fp8的特性,针对transformer里面的kernel进行了重写,包含了一系列layernorm, gelu, scaledsoftmax等。
使用方式也是比较简单,使用该拓展额外包的一层module来搭建网络,即可,最后再包一层混合精度训练作用域:
import torchimport transformer_engine.pytorch as tefrom transformer_engine.common import recipe# set dimensions.in_features = 768out_features = 3072hidden_size = 2048# initialize model and inputs.model = te.linear(in_features, out_features, use_bias=true)inp = torch.randn(hidden_size, in_features, device=cuda)# 创建fp8训练的配置fp8_recipe = recipe.delayedscaling(margin=0, interval=1, fp8_format=recipe.format.e4m3)# fp8的autocastwith te.fp8_autocast(enabled=true, fp8_recipe=fp8_recipe): out = model(inp)loss = out.sum()loss.backward()
本篇博客就简单介绍下transformer engine及其对应实现原理
官方文档:https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
transfromer engine 是干啥的?
在各种以transformer为基础的语言模型如gpt3大火后,语言模型的参数量还在以指数形式增长:
那么优化transformer性能就显得格外重要了,其中混合精度训练是一个很实用的技巧
在fp16下,其数据范围还是足够大的,因此在amp下,我们只在最后的loss做了一个scaling,这个步骤足以保证在整个模型运算过程中不会产生溢出
而fp8相比fp16减少了更多有效位,因此不能简单地复用fp16下的策略,需要给每个fp8 tensor单独设置一个合适的scale factor。transformer engine 需要动态地对输入范围进行调整,如图所示:
上图来自h100白皮书内(当时我还天真的以为有一个专门的硬件做这个处理。。。)
下面我们简单看下其代码和实现原理
kernel实现
具体到每一个算子实现动态范围调整的原理其实很简单,通过记录历史的abs max值,来去调整最终缩放的范围。
其主要的kernel实现都放在了 common 目录下,我们以gelu这个kernel为例,最终它会调用到 vectorized_pointwise.h这个文件,我们主要看 unary_kernel
unary_kernel
这个核函数模板跟常规的elementwise向量化模板是类似的。
首先会让每个线程获取到scale值
computetype s = 0;if constexpr (is_fp8::value) { // 获取scale值 if (scale != nullptr) s = *scale; // 将scale取倒数写回scale_inv if (blockidx.x == 0 && threadidx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); }}
其中在循环里,线程会不断更新他运算结果的最大值,并且最终运算结果要乘上scale值:
// 实际运算发生computetype temp = op(val, p);if constexpr (is_fp8::value) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); // 缩放 temp = temp * s;}
当kernel主体运算完毕后,再也warp为单位做一个reduce_max,获取到线程束内的最大值,再通过atomicmax原子操作,不断更新全局最大值:
if constexpr (is_fp8::value) { /* warp tile amax reduce*/ max = reduce_max(max, warp_id); if (threadidx.x == 0 && amax != nullptr) { static_assert(std::is_same::value); // 更新全局最大值 atomicmaxfloat(amax, max); }}
其他layernorm等kernel也是诸如类似的逻辑,这里就不再展开了
python api
(1) delayedscaling
从前面的示例代码我们可以看到一个比较重要的api是delayedscaling,我们可以根据官方文档查看各个参数含义:
margin 计算scale的偏移量
interval 控制计算scale factor的频率
fp8_format 使用fp8的格式,fp8有e4m3和e5m2,但是现在不支持纯e5m2的格式训练
amax_history_len 记录abs maxval的历史窗口大小
amax_compute_algo 在窗口里选择absmax的算法,'max'则是选择历史窗口里最大值,'most_recent'则是选择近期的值,当然你也可以传一个自定义的函数
相关代码为:
@torch.jit.scriptdef _default_get_amax( amax_history: torch.tensor, amax_compute_algo: str,) -> tuple[torch.tensor, torch.tensor]: default function to obtain amax from history. if amax_compute_algo == max: amax = torch.max(amax_history, dim=0).values else: # amax_compute_algo == most_recent amax = amax_history[0] amax_history = update_amax_history(amax_history) return amax_history, amax
scaling_factor_compute_algo 计算scale factor的算法
@torch.jit.scriptdef _default_sf_compute( amax: torch.tensor, scale: torch.tensor, fp8_max: float, margin: int,) -> torch.tensor: default function to convert amax to scaling factor. exp = torch.floor(torch.log2(fp8_max / amax)) - margin sf = torch.round(torch.pow(2, torch.abs(exp))) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) sf = torch.where(exp none: ... self.fp8 = false self.fp8_meta = {} self.fp8_meta[fp8_group] = none self.fp8_meta[recipe] = get_default_fp8_recipe() def fp8_init(self, num_gemms: int = 1) -> none: initialize fp8 related metadata and tensors during fprop. # if fp8 isn't enabled, turn off and return. if not is_fp8_enabled(): self.fp8 = false return # fp8 is already enabled and recipe is the same, don't do anything. if self.fp8 and get_fp8_recipe() == self.fp8_meta[recipe]: return # set fp8, recipe, and other fp8 metadata self.fp8 = true self.fp8_meta[recipe] = get_fp8_recipe() self.fp8_meta[num_gemms] = num_gemms self.fp8_meta[fp8_group] = get_fp8_group() # set fp8_max per tensor according to recipe self.fp8_meta[fp8_max_fwd] = self.fp8_meta[recipe].fp8_format.value.max_fwd self.fp8_meta[fp8_max_bwd] = self.fp8_meta[recipe].fp8_format.value.max_bwd # allocate scales and amaxes self.init_fp8_meta_tensors()
而相关module如layernormmlp继承该module,并且传入fp8_meta信息更新:
class layernormmlp(transformerenginebasemodule): def forward(...): out = _layernormmlp.apply( ..., self.fp8, self.fp8_meta, )
总结
大致浏览完其实思路不复杂,但感觉还是fp8技术的不稳定,整个项目还是加入了很多限制。得益于pytorch灵活的外部扩展形式,只要不去触碰框架底层运行机制,仅仅在算子层面上的修改还是相当简单。虽然不具备通用性,但是运算主体就这几个算子,为了性能也是可以接受的
配电箱远程智能监控方案
主流无线技术的选型困局
重点领域人工智能治理挑战及对策研究工作研讨会 在线上成功举办
在AEIMR五大行业中寻找国产智能芯片“同时起跑”的新机遇
在工业领域,用嵌入式还是用PLC?
详解NVIDIA H100 TransformerEngine
这才是完整高通骁龙835旗舰,努比亚Z17将全球首发QC4快充!
数据中心电源路径问题解决方案
车载显示屏的安装方式
九齐单片机的应用领域及优势
汽车革命逼近 或是未来的智能手机?
6G到来之后的生活会是怎样的
中兴通讯分布式精准云助力运营商掌控云网生态
单片机按键的介绍独立按键与矩阵键盘的概述
电脑的CPU风扇应该如何安装详细教程说明
5G毫米波成为MWC 2021上海展众多目光汇聚的热点之一
人工智能的迅速发展 正在全方位深刻改变人们的生产生活
高通骁龙835成本价曝光,要不雷军说小米6成本高呢?
哈曼将利用Innoviz的激光雷达产品进一步巩固技术领先供应商的地位
纵联差动保护与横差保护的区别