Connect With Us

clotilde07@hermann.com

Call For Inquiry

45-9601175

Opening hours

Mon - Sun : 09:00 - 16:00

記事 Google DeepMindのNFNetがディープラーニングを効率化

Google DeepMindのNFNetがディープラーニングを効率化

Google DeepMindのNFNetがディープラーニングを効率化
491
630
700
369

原文(投稿日:2021/03/26)へのリンク

GoogleのDeepMind AI企業は最近NFNetsをリリースした。これは、ノーマライザーフリーResNet画像分類モデルであり、現在の最先端のEfficientNetよりも8.7倍速いトレーニングパフォーマンスを実現した。

GoogleのDeepMindの研究者によると次の通りである(以下のプロットをチェックしてください)。

NFNet-F1モデルは、EfficientNet-B7と同様の精度を実現している。一方で、トレーニングでは8.7倍高速である。我々の最大のモデルは、86.5%のトップ1精度の追加データなしで、新しい総合的な最先端技術に位置付けられる。

Google DeepMindのNFNetがディープラーニングを効率化

大規模な画像認識タスクの場合、通常、ニューラルネットワークはバッチ正規化と呼ばれる手法を使用し、それによってモデルトレーニングがより効率的になる。さらに、ニューラルネットワークがより一般化する助けとなる。つまり、正則化効果がある。

バッチ正規化には、トレーニング時間と推論時間の間の食い違いの動作や、計算オーバーヘッドといったいくつかの欠点があある。計算オーバーヘッドは、後のバックプロパゲーション(ニューラルネットワーク学習プロセス)に必要なネットワーク層ごとの特定のパラメーターの格納によるものである。

DeepMindは、方程式から正規化を削除し、トレーニングパフォーマンスを向上させるためにNFNetを導入した。これに加えて、適応勾配クリッピングと呼ばれる手法が導入されている。これによって、ResNetなどのニューラルネットワークモデルをより大きなバッチサイズで効率的にトレーニングできる。この方法では、同じ精度のEfficientNetと比較して、計算リソース(使用されるGPUの量)ごとにトレーニング時間が20~40%短縮された。

出典: 正規化なしの高性能大規模画像認識

コードはGoogleのDeepMind GitHubで公開され、JAXと呼ばれるこの新しいフレームワークに実装された。 NFNetでフォワードステップを実行するには、次のコードを実行するだけである。

def forward(inputs, is_training): model = nfnet.NFNet(num_classes=1000,variant=variant) return model(inputs, is_training=is_training)['logits']net = hk.without_apply_rng(hk.transform(forward))fwd = jax.jit(lambda inputs: net.apply(params, inputs, is_training=False))# We split this into two cells so that we don't repeatedly jit the fwd fn.logits = fwd(x[None]) # Give X a newaxis to make it batch-size-1which_class = imagenet_classlist[int(logits.argmax())]print(f'ImageNet class: {which_class}.')

NFNetsはPytorchにも実装されており、コミュニティがこのリリースを受け入れていることを示している。

import torchfrom torch import nn, optimfrom torchvision.models import resnet18from nfnets import WSConv2dfrom nfnets.agc import AGC # Needs testingconv = nn.Conv2d(3,6,3)w_conv = WSConv2d(3,6,3)optim = optim.SGD(conv.parameters(), 1e-3)optim_agc = AGC(conv.parameters(), optim) # Needs testing# Ignore fc of a model while applying AGC.model = resnet18()optim = torch.optim.SGD(model.parameters(), 1e-3)optim = AGC(model.parameters(), optim, model=model, ignore_agc=['fc'])

最後に、NFNetに関するYouTubeビデオの再生回数は30,000回を超えている。

タグ: aiはあなたの会社のインテリジェンスaiを助けますか