avatar
tkat0.dev
Published on

Distillerのthinningの仕様

Distillerは Pytorch 向けの pruning や quantization を行うライブラリだが、Channel Pruning 後に Weight を shrink してサイズを小さくする"thinning"ができる。 今回はそのソースを読んで気がついたことのまとめ。

これは昔書いた記事

Distiller で DeepLearning のモデルを軽量化: Gradual Pruning 編 - tkato’s blog http://tkat0.hateblo.jp/entry/2018/05/22/082911

Distiller とは PyTorch 向けのモデル圧縮ライブラリです。以下のような特徴があります。

  • 数種類の枝刈り(pruning), 量子化(quantization), 正則化(regularization)アルゴリズムを実装
  • 既存の学習スクリプトの training loop に追加するだけで使える
  • 設定は YAML。モデルのレイヤー単位で pruning のパラメータを変えるなど柔軟な設定。
  • TensorBoard と連携した、モデルの weight や精度の可視化

今回見る thinning だが、以下で説明している rebuild と同じもの。

https://www.slideshare.net/tkatojp/chainerchannel-pruning-125938007

対象とするソースは以下。

https://github.com/NervanaSystems/distiller/blob/e564a05f47a9e15e8575615d1ba92358b9184b67/distiller/thinning.py#L223-L227

この API で、model から値がゼロの filter を除外してコンパクトな weight を作る。

def remove_filters(model, zeros_mask_dict, arch, dataset, optimizer):
    sgraph = create_graph(dataset, arch)
    thinning_recipe = create_thinning_recipe_filters(sgraph, model, zeros_mask_dict)
    apply_and_save_recipe(model, zeros_mask_dict, thinning_recipe, optimizer)
    return model

arch や dataset を指定する時点でお察しだけど、cifar や imagenet 向け特定のモデルしか対応していない。 もちろん、内部で呼んでる関数をダイレクトにつなげれば任意のモデルを thinning できる(はずだった)。

中で何をやっているかを簡単に説明すると

  1. PyTorch のモデルをSummaryGraphクラスのオブジェクトへ変換する (create_graph)
    • 指定したレイヤーの後続するレイヤーを取得したり(successors_f())できるクラス
    • torch.jit.get_trace_graph で計算グラフをトレースし、その後 self.opsself.edgesといったデータ構造に入れる
    • 層ごとの計算量(MACs)もここで計算
  2. 値がゼロになるフィルターを除外し、その層に後続する Conv/FC/BN の入力も合わせて調整する(create_thinning_recipe_filters, apply_and_save_recipe

2 が微妙で、各層に対する実装がハードコーディングされている。

  • 例えば後続する層が GroupConv(depth-wise とか)の場合は対応していないので変換できない
  • 現在の層の直後 1 つしか考慮していない
    • conv-bn-conv-bn-fc という構造前提のアルゴリズム。大体のモデルはそうだろうけど。。。

なので、例えば MobileNet とかは thinning できないということ。ああ〜。

それで作ったわけですが(忙しくてメンテできてない)。

https://github.com/DeNA/ChainerPruner