热门
最新
红包
立Flag
投票
同城
我的
发布
CSDN App 扫码分享
评论
1
打赏
- 复制链接
- 举报
下一条:
8-bit 和 4-bit ADAM 实现现在已经在 torchao 中可用,感谢 @miamia75 https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim。相对于 fp32,它们分别将你的优化器状态减少了 4 倍和 8 倍。这些实现完全用 PyTorch 编写,然后通过 torch.compile 以实现与经过专家编写的 CUDA 内核性能相当的表现。你可以将它们作为直接替换使用:```pythonfrom torchao.prototype.low_bit_optim import Adam4bit, Adam8bitoptim = Adam8bit(model.parameters())optim = Adam4bit(model.parameters())```减少优化器内存是一件大事,因为 ADAM 需要比你的模型参数多 2 倍的 VRAM!在另一端,我们有 SGD,它不需要额外的 VRAM,但不收敛。多年来也提出了其他更节省内存的优化器,但还没有像 ADAM 一样经受住时间的考验。这项工作是站在巨人的肩膀上的,我们非常感谢 https://github.com/TimDettmers/bitsandbytes 和 https://github.com/thu-ml/low-bit-optimizers。请注意,我们的代码仍然比他们的专家级 CUDA 内核慢大约 20%,但我们希望你会发现我们的代码作为一个可修改的纯 Python 实现是有用的。一些简单的变体可能具有更好的收敛性,所以请复制粘贴我们的代码并进行分叉,如果你希望我们为你打包,请考虑向上游提交!如果你已经读到这里,那么我会分享一些有趣的技术细节:1. 转换为较小的 dtypes,尤其是像 int4,这对原始 dtypes 中的异常值非常敏感,所以转换通常按块中的最大值进行缩放。如果这句话让你感到困惑,你可以通过在此处添加打印语句来浏览完整的量化算法 https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim2. 在确定如何将张量转换为较低 dtype 时,通常会有归一化的浮点数,而不是均匀分布在一个范围内的浮点数。在这种情况下,运行二分搜索算法是有益的,但编写 Triton 中的二分搜索例程相当困难,所以人们通常依赖 CUDA。相反,我们发现可以在纯 PyTorch 中实现二分搜索,其中“索引”在二分搜索中总是向上递增。了解更多请点击这里 https://github.com/pytorch/ao/pull/478```pythoninput = input.view(-1)codes = torch.where(input >= qmap[8], 8, 0)codes += torch.where(input >= qmap[codes + 4], 4, 0)codes += torch.where(input >= qmap[codes + 2], 2, 0)codes += torch.where(input >= qmap[codes + 1], 1, 0)```3. 对于 int8,我们在 PyTorch 中已经有一个类型,但没有 int4 dtype,所以我们可以用这一行将 2 个 int4 位打包成一个 int8,并依赖编译器通过位打包或解包来融合计算。```pythoncodes = (codes[::2] << 4) | codes[1::2]```4. 优化器支持作为张量子类 https://github.com/albanD/subclass_zoo,这是一个新的 PyTorch 特性,最大限度地提高了与其他 PyTorch 子系统的组合能力。特别是新的优化器需要实现一个 `lerp` 操作,这用于实现 `opt.step() -> out = start + weight x (end - start)`。请记住,这里完全没有 C++,所以如果你一直想进入量化研究但被 CUDA 阻碍,请查看我们 https://github.com/pytorch/ao。