e.blog

主にUnity/UE周りのことについてまとめていきます

バイトニックソートの実装を理解する

概要

以前書いた粒子法を用いた流体シミュレーションをさらに発展させ近傍探索を行って最適化をしています。
その中で使っている『バイトニックソート』というソートについてまとめたいと思います。

本記事は近傍探索を実装する上でのサポート的な記事です。
近いうちに近傍探索の実装についても書こうと思っています。

なお、参考にさせていただいた記事は流体シミュレーション実装を参考にさせていただいた@kodai100さんが書いている記事です。

qiita.com

内容は近傍探索についてですがその中でバイトニックソートについての言及があります。

ちなみに流体シミュレーション自体についても記事を書いているので興味がある方はご覧ください。

edom18.hateblo.jp



バイトニックソートとは

Wikipediaによると以下のように説明されています。

バイトニックマージソート(英語: Bitonic mergesort)または単にバイトニックソート(英語: Bitonic sort)とは、ソートの並列アルゴリズムの1つ。ソーティングネットワークの構築法としても知られている。

このアルゴリズムはケン・バッチャー(英語: Ken Batcher)によって考案されたもので、このソーティングネットワークの計算量はn個の要素に対して、サイズ(コンパレータ数=比較演算の回数)は O(n log^{2}(n))、段数(並列実行不可能な数)は O(log^{2}(n))となる[1]。各段での比較演算(n/2回)は独立して実行できるため、並列化による高速化が容易である。

自分もまだしっかりと理解できてはいませんが、ソーティングネットワークを構築することで配列の中身を見なくともソートができる方法のようです。

この配列の中を見なくてもというのがポイントで、決められた順に処理を実行していくだけでソートが完了します。
言い換えると並列に処理が可能ということです。

Wikipediaでも

並列化による高速化が容易である。

と言及があります。

GPU(コンピュートシェーダ)によって並列に計算する必要があるためこの特性はとても重要です。

ロジックを概観する

まずはWikipediaに掲載されている以下の図を見てください。

f:id:edo_m18:20200916090952p:plain

最初はなんのこっちゃと思いましたが、ひとつずつ見ていけばむずかしいことはしていません。

まずぱっと目に着くのは線と色のついた各種ボックスだと思います。
この図が言っているのは配列の中身が16要素あり、それをソートしていく様を示しています。

一番左が初期状態で、一番右がソートが完了した状態です。
よく見ると横に長く伸びる線が16本あることに気づくと思います。
これが配列の要素数を表しています。


余談ですが、なぜこういう図なのかと言うと。
ソーティングネットワーク自体がこういう概念っぽいです。

横に伸びる線をワイヤー、矢印の部分がコンパレータと呼ばれます。
そしてワイヤーの左からデータを流すと、まるであみだくじの要領でソートが完了します。
そのためにこういう図になっているというわけなんですね。


ソートされていく様子はあみだくじを想像するといいかもしれません。
左からデータが流れてきて、決まったパターンでデータが入れ替わっていき、最後にソートが完了している、そんなイメージです。

このデータが入れ替わっていく部分は矢印が表しています。
また矢印の向きは降順・昇順どちらに値を入れ替えるかを表しています。

並列可能部分と不可能部分

まず注目すべきはブロックによって区切られている点です。
以下の図をご覧ください。

f:id:edo_m18:20200918090923p:plain

メインブロックと書かれたところが大きく4つに分かれています。
そしてそのメインブロックの中にサブブロック郡があります。

ここで注意する点は並列計算可能な部分並列計算不可能な部分がある点です。
図を見てもらうと分かりますが、メインブロック内の矢印に着目するとそれぞれは独立して処理を行えることが分かります。
どの矢印から処理を開始しても結果は変わりません。

しかしメインブロック自体の計算順序を逆にしてしまうと結果が異なってしまいます。
これは並列実行できないことを意味しています。
今回の目的はGPUによって並列計算を行わせることなのでここの把握は重要です。

つまりメインブロックは並列不可、サブブロックは並列可ということです。

計算回数

次にブロックの処理順について法則を見てみましょう。

メインブロックは全部で4つあります。そして配列数は16です。
この関連性は 2^4=16)から来ています。

これは推測ですが、バイトニックの名前の由来はこの2進数から来ているのかもしれません。

さて、ではサブブロックはどうでしょうか。
サブブロックにも法則があります。
それは左から順に1, 2, 3, 4, ...と数が増えていることです。

この法則はコードを見てみると分かりやすいです。

for (int i = 0; i < 4; i++)
{
    for (int j = 0; j <= i; j++)
    {
        // ソート処理
    }
}

外側のforループがメインブロックのループを表していて、内側のループがサブブロックのループを表しています。
そして内側のループは外側のループが回るたびに回数が増えていく形になっています。

外側のループが1回なら内側も1回だけ実行され、外側の2ループ目は内側は2回ループする、という具合です。
なので外側のループが回るたびに内側のループの回数が増加していくというわけなんですね。

計算最大回数

今回は要素数16なので4でしたが、これが32なら5回ループが回るということですね。
もちろん、内側のループもそれに応じて増えていきます。
メインブロックの最大計算回数は素数2の何乗かに依るわけです。

ちなみに感の良い方ならお気づきかもしれませんが、2のべき乗で計算がされるということはそれ以外の要素数ではソートが行えないことを意味しています。
なのでもし要素数が2のべき乗以外の数になる場合はダミーデータなどを含めて2のべき乗に揃える必要があります。

矢印の意味

さらに詳細を見ていきましょう。

次に見るのは矢印です。
この矢印は配列内の要素を入れ替える(Swapする)ことを意味しています。
矢印なので向きがありますね。これは昇順・降順どちらに入れ替えるかを示しています。

よく見ると青いブロック内は昇順、緑のブロック内は降順に入れ替わっていることが分かります。
そして図の通りに入れ替えを進めていくと最終的にソートが完了している、というのがバイトニックソートです。

ひとつの解説だけだと解像度が足らないので別の記事でも探してみると、以下の記事と画像が理解を深めてくれました。

seesaawiki.jp

画像を引用させてもらうと以下のような感じでソートが進んでいきます。

f:id:edo_m18:20200915093123p:plain f:id:edo_m18:20200915093532p:plain

言っていることはWikipediaと同じですが実際の数値が並び変えられていくのでより理解が深まるかと思います。

ちなみにこの入れ替え手順先に示したコード通りになっているのが分かります。
各配列の下に添えられている数字を見ると2のべき乗の部分が0, 1, 0, 2, 1, 0と変化しているのが分かると思います。
これをグループ化して見てみると[0], [1, 0], [2, 1, 0]ということですね。
外側のループ回数が増えるにつれて内側のループが増えていくということと一致しています。

比較する対象の距離と方向を求める

さて、ループについては把握できたかと思います。
次に見るのはどの要素同士を入れ替えるかという点です。

入れ替える距離はループ数によって決まる

ループの仕方が分かっても、闇雲に配列の内容を入れ替えたのでは当然ソートはできません。
ではどういうルールで入れ替えていけばいいのでしょうか。

その答えは以下の計算です。

public static void Kernel(int[] a, int p, int q)
{
    int d = 1 << (p - q);
    
    for (int i = 0; i < a.Length; i++)
    {
        bool up = ((i >> p) & 2) == 0;
        
        if ((i & d) == 0 && (a[i] > a[i | d]) == up)
        {
            int t = a[i];
            a[i] = a[i | d];
            a[i | d] = t;
        }
    }
}

public static void BitonicSort(int logn, int[] a)
{
    for (int i = 0; i < logn; i++)
    {
        for (int j = 0; j <= i; j++)
        {
            Kernel(a, i, j);
        }
    }
}

距離に関しては以下の式で求めています。

// distanceのd
int d = 1 << (p - q);

ここでpは外側のループ、qは内側のループの回数が渡ってきます。
これを引き算の部分だけ見てみると、ループが進むに連れて以下のように計算されます。

0 - 0 = 0
1 - 0 = 1
1 - 1 = 0
2 - 0 = 2
2 - 1 = 1
2 - 2 = 0
...

引き算の結果は1bitをどれだけ左にシフトするかの数値なので、つまりは2を何乗するかを示しているわけですね。
これを把握した上で改めて先ほどの図を見てみると、確かにそう変化していっているのが分かると思います。

f:id:edo_m18:20200915093123p:plain f:id:edo_m18:20200915093532p:plain

昇順・降順は2bit目が立っているかで切り替える

昇順・降順を決めている計算は以下の部分です。

for (int i = 0; i < a.Length; i++)
{
    bool up = ((i >> p) & 2) == 0;

    // 後略
}

iは要素数分ループする回数を示しています。そしてpは前述の通り、外側のループ、つまりメインブロックの計算回数を示しています。

つまり、全要素をループさせ、かつそのループ回数をメインブロックの計算回数値だけ右にシフトし、そのときのビット配列の2bit目が立っているか否かで昇順・降順を切り替えているわけですね。(ちなみに2bit目が立っている場合は降順

ちょっとしたサンプルを書いてみました。
以下のpaizaのコードを実行すると、上の図の昇順・降順の様子と一致していることが分かるかと思います。
(サンプルコードの↓が昇順、↑が降順を表しています)

Swap処理

最後に、どの場合にどこと入れ替えるかの処理について見てみましょう。

bool up = ((i >> p) & 2) == 0;

if ((i & d) == 0 && (a[i] > a[i | d]) == up)
{
    int t = a[i];
    a[i] = a[i | d];
    a[i | d] = t;
}

最初の行のupのは昇順か降順かのフラグです。
続くif文が実際にSwapをするかを判定している箇所になります。

条件が2つ書かれているのでちょっと分かりづらいですが、分解してみると以下の2つを比較しています。

// ひとつめ
(i & d) == 0
// ふたつめ
(a[i] > a[i | d]) == up

ひとつめは要素の位置とdとの理論積になっていますね。

ふたつめは、配列の要素のふたつの値を比較し、upフラグの状態と比較しています。
これは昇順か降順かを判定しているに過ぎません。

問題は右側の要素へのアクセス方法ですね。
a[i | d]はなにをしているのでしょうか。

これらが意味するところは以下の記事がとても詳しく解説してくれています。

qiita.com

この記事から引用させてもらうと、

あるインデックスに対してそれと比較するインデックスはdだけ離れています。そのため2つのインデックスをビットで考えると値はp - qビット目の値が0か1かの違いだけになります(一番右端のビットを0ビット目として数えています)。配列の先頭に近いほうにあるインデックスをiとすると比較対象のインデックスはp - qビット目が1になるのでそのインデックスはi | dになります。つまり、if文内の(i & d) == 0は配列の先頭に近いほうにあるインデックスかどうかを確認しており、x[i] > x[i | d]で2つの値の大小を確認していることになります。

と書かれています。

文章だけだとちょっと分かりづらいですが、実際にbitを並べて図解してみると分かりやすいと思います。
試しに要素数8の場合で書き下してみると、

000 = 0
001 = 1
010 = 2
011 = 3
100 = 4
101 = 5
110 = 6
111 = 7

というふうになります。値の意味は配列の添字です。(要素数8なので0 ~ 7ということです)

以下の部分を考えてみましょう。

あるインデックスに対してそれと比較するインデックスはdだけ離れています。

仮にd = 1 << 0だとするとdの値は1です。つまりひとつ隣ということですね。

000 = 0 ┐
001 = 1 ┘
010 = 2 ┐
011 = 3 ┘

比較する対象はこうなります。そして引用元では、

そのため2つのインデックスをビットで考えると値はp - qビット目の値が0か1かの違いだけになります(一番右端のビットを0ビット目として数えています)。

と書かれています。上の例ではp - q == 0としているので、つまりは一番右側(0番目)のビットの違いを見れば良いわけです。
見てみると確かに違いは0ビット目の値の違いだけであることが分かります。

冗長になるのでこれ以上深堀りはしませんが、実際に書き下してみると確かにその通りになるのが分かります。
そしてここを理解するポイントは以下です。

  • 比較対象のうち、配列の先頭に近い方のインデックスの場合のみ処理する
  • 先頭に近いインデックスだった場合は、そのインデックスとそのインデックスからdだけ離れた要素と比較する

ということです。
まぁ細かいことは置いておいても、for文で全部を処理している以上、重複してしまうことは避けられないので、それをビットの妙で解決しているというわけですね。

これを一言で言えば、先頭のインデックスだった場合は、そのインデックスとdだけ離れたインデックス同士を比較するということです。

コード全体

最後にコード全体を残しておきます。
以下はC#で実装した例です。コード自体はWikipediaJavaの実装をそのまま移植したものです。

// This implementation is refered the WikiPedia
//
// https://en.wikipedia.org/wiki/Bitonic_sorter
public static class Util
{
    public static void Kernel(int[] a, int p, int q)
    {
        int d = 1 << (p - q);
        
        for (int i = 0; i < a.Length; i++)
        {
            bool up = ((i >> p) & 2) == 0;
            
            if ((i & d) == 0 && (a[i] > a[i | d]) == up)
            {
                int t = a[i];
                a[i] = a[i | d];
                a[i | d] = t;
            }
        }
    }
    
    public static void BitonicSort(int logn, int[] a)
    {
        for (int i = 0; i < logn; i++)
        {
            for (int j = 0; j <= i; j++)
            {
                Kernel(a, i, j);
            }
        }
    }
}

public class Example
{
    public static void Main()
    {
        int logn = 5, n = 1 << logn;
        
        int[] a0 = new int[n];
        System.Random rand = new System.Random();
        for (int i = 0; i < n; i++)
        {
            a0[i] = rand.Next(n);
        }
        
        for (int k = 0; k < a0.Length; k++)
        {
            System.Console.Write(a0[k] + " ");
        }
        
        Util.BitonicSort(logn, a0);
        
        System.Console.WriteLine();
        
        for (int k = 0; k < a0.Length; k++)
        {
            System.Console.Write(a0[k] + " ");
        }
    }
}

実際に実行する様子は以下で見れます。

まとめ

まとめると、バイトニックソートは以下のように考えることができます。

  • 比較する配列の要素は常にふたつ
  • 比較対象はビット演算によって求める
  • 昇順・降順の判定もビットの立っている位置によって決める
  • 配列の要素の比較重複(0 -> 1と0 <- 1という向きの違い)もビットの位置によって防ぐ

分かってしまえばとてもシンプルです。
が、理解もできるし使うこともできるけれど、これを思いつくのはどれだけアルゴリズムに精通していたらできるんでしょうか。
こうした先人の知恵には本当に助けられますね。

XRCameraSubSystemから直接カメラの映像を取得する

概要

今回はUnityのARFoundationが扱うシステムからカメラ映像を抜き出す処理についてまとめたいと思います。
これを利用する目的は、カメラの映像をDeep Learningなどに応用してなにかしらの出力を得たいためです。

今回の画像データの取得に関してはドキュメントに書かれているものを参考にしました。

docs.unity3d.com

今回のサンプルを録画したのが以下の動画です。

今回のサンプルはGitHubにアップしてあるので、詳細が気になる方はそちらをご覧ください。

github.com

全体の流れ

画像を取得する全体のフローを以下に示します。

  1. XRCameraImageを取得する
  2. XRCameraImage#Convertを利用してデータを取り出す
  3. 取り出したデータをTexture2Dに読み込ませる
  4. Texture2Dの画像を適切に回転しRenderTextureに書き出す

という流れになります。

ということでひとつずつ見ていきましょう。

XRCameraImageを取得し変換する

ここではXRCameraImageからデータを取得し、Texture2Dに書き込むまでを解説します。

まずはARのシステムからカメラの生データを取り出します。
ある意味でこの工程が今回の記事のほぼすべてです。

取り出したらDeep Learningなどで扱えるフォーマットに変換します。

今回実装したサンプルのコードは以下の記事を参考にさせていただきました。

qiita.com

ドキュメントの方法でも同様の結果を得ることができますが、テクスチャの生成を制限するなど最適化が入っているのでこちらを採用しました。

以下に取り出し・変換する際のコード断片を示します。

private void RefreshCameraFeedTexture()
{
    // TryGetLatestImageで最新のイメージを取得します。
    // ただし、失敗の可能性があるため、falseが返された場合は無視します。
    if (!_cameraManager.TryGetLatestImage(out XRCameraImage cameraImage)) return;

    // 中略

    // デバイスの回転に応じてカメラの情報を変換するための情報を定義します。
    CameraImageTransformation imageTransformation = (Input.deviceOrientation == DeviceOrientation.LandscapeRight)
        ? CameraImageTransformation.MirrorY
        : CameraImageTransformation.MirrorX;

    // カメライメージを取得するためのパラメータを設定します。
    XRCameraImageConversionParams conversionParams =
        new XRCameraImageConversionParams(cameraImage, TextureFormat.RGBA32, imageTransformation);

    // 生成済みのTexture2D(_texture)のネイティブのデータ配列の参照を得ます。
    NativeArray<byte> rawTextureData = _texture.GetRawTextureData<byte>();

    try
    {
        unsafe
        {
            // 前段で得たNativeArrayのポインタを渡し、直接データを流し込みます。
            cameraImage.Convert(conversionParams, new IntPtr(rawTextureData.GetUnsafePtr()), rawTextureData.Length);
        }
    }
    finally
    {
        cameraImage.Dispose();
    }

    // 取得したデータを適用します。
    _texture.Apply();

    // 後略
}

Texture2Dの画像を適切に回転しRenderTextureに書き出す

前段でXRCameraImageからデータを取り出しTexture2Dへ書き出すことができました。
ただ今回は最終的にTensorFlow Liteで扱うことを想定しているのでRenderTextureに情報を格納するのがゴールです。

ぱっと思いつくのはGraphics.Blitを利用してRenderTextureにコピーすることでしょう。
しかし、取り出した画像は生のデータ配列のため回転を考慮していません。(つまりカメラからの映像そのままということです)

以下の質問にUnityの中の人からの返信があります。

forum.unity.com

TryGetLatestImage exposes the raw, unrotated camera data. It will always be in the same orientation (landscape right, I believe). The purpose of this API is to allow for additional CPU-based image processing, such as with OpenCV or other computer vision library. These libraries usually have a means to rotate images, or accept images in various orientations, so we there is no built-in functionality to rotate the image.

要は、だいたいの場合において利用する対象(OpenCVなど)に回転の仕組みやあるいは回転を考慮しないでそのまま扱える機構があるからいらないよね、ってことだと思います。

そのため、人が見て適切に見えるようにするためには画像を回転してコピーする必要があります。
ですが心配いりません。処理自体はとてもシンプルです。

基本的には時計回りに90度回転させるだけでOKです。

なにも処理しない画像をQuadに貼り付けると以下のような感じで90度回転したものが出力されます。
(ちょっと分かりづらいですが、赤枠で囲ったところはAR空間に置かれたQuadで、そこにカメラの映像を貼り付けています)

f:id:edo_m18:20200823113553p:plain

これを90度回転させるためにはUVの値を少し変更するだけで達成することができます。

まずはシェーダコードを見てみましょう。

シェーダで画像を回転させる

見てもらうと分かりますが、基本はシンプルなImage EffectシェーダでUVの値をちょっと工夫しているだけです。

Shader "Hidden/RotateCameraImage"
{
    Properties
    {
        _MainTex ("Texture", 2D) = "white" {}
    }
    SubShader
    {
        Cull Off ZWrite Off ZTest Always

        Pass
        {
            CGPROGRAM
            #pragma vertex vert
            #pragma fragment frag

            #include "UnityCG.cginc"

            struct appdata
            {
                float4 vertex : POSITION;
                float2 uv : TEXCOORD0;
            };

            struct v2f
            {
                float2 uv : TEXCOORD0;
                float4 vertex : SV_POSITION;
            };

            v2f vert (appdata v)
            {
                v2f o;
                o.vertex = UnityObjectToClipPos(v.vertex);
                o.uv = v.uv;
                return o;
            }

            sampler2D _MainTex;

            fixed4 frag (v2f i) : SV_Target
            {
                float x = 1.0 - i.uv.y;
                float y = i.uv.x;
                float2 uv = float2(x, y);
                fixed4 col = tex2D(_MainTex, uv);
                return col;
            }
            ENDCG
        }
    }
}

x, yを反転して、さらにyの値を1.0から引いているだけです。簡単ですね。

そしてこのシェーダを適用したマテリアルを用いてGraphics.Blitを実行してやればOKです。

Graphics.Blit(texture, _previewTexture, _transposeMaterial);

分かりやすいように、グリッドの画像で適用したものを載せます。

f:id:edo_m18:20200823112925p:plain

90度右に回転しているのが分かるかと思います。

これで無事、画像が回転しました。

バイスの回転を考慮する

実は上のコードだけでは少し問題があります。
バイスの回転によって取得される画像データの見栄えが変わってしまうのです。

というのは、Portraitモードでは回転しているように見える画像でも、Landscapeモードだとカメラからの映像と見た目が一致して問題なくなるのです。
以下の動画を見てもらうと分かりますが、Portraitモードでは90度回転しているように見える画像が、Landscapeモードでは適切に見えます。

f:id:edo_m18:20200823114155g:plain

結論としてはPortraitモードのときだけ処理すればいいことになります。

private void PreviewTexture(Texture2D texture)
{
    if (_needsRotate)
    {
        Graphics.Blit(texture, _previewTexture, _transposeMaterial);
    }
    else
    {
        Graphics.Blit(texture, _previewTexture);
    }

    _renderer.material.mainTexture = _previewTexture;
}

バイスが回転した際のイベントが実はUnityには用意されていないようで、以下の記事を参考に回転の検知を実装しました。
(まぁゲームにおいて回転を検知してなにかをする、っていうケースが稀だからでしょうかね・・・)

forum.unity.com

カメライメージを取得するタイミング

最後にカメライメージの取得タイミングについて書いておきます。
ドキュメントにも書かれていますが、ARCameraManagerのframeReceivedというイベントのタイミングでカメライメージを取得するのが適切なようです。

ARCameraManager#frameReceivedイベント

ARCameraManagerにはframeReceivedというイベントがドキュメントでは以下のように説明されています。

An event which fires each time a new camera frame is received.

カメラフレームを受信したタイミングで発火するようですね。
なのでこのタイミングで最新のカメラデータを取得することで対象の映像を取得することができるというわけです。

ということで、以下のようにコールバックを設定してその中で今回の画像取得の処理を行います。

[SerializeField] private ARCameraManager _cameraManager = null;

// ---------------------------

_cameraManager.frameReceived += OnCameraFrameReceived;

// ---------------------------

private void OnCameraFrameReceived(ARCameraFrameEventArgs eventArgs)
{
    RefreshCameraFeedTexture();
}

最後に

無事、カメライメージを取得して扱える状態に変換することができました。
Texture2DとしてもRenderTextureとしても扱えるので用途に応じて使うといいでしょう。

気になる点としてはパフォーマンスでしょうか。
一度CPUを経由しているのでそのあたりが気になるところです。(まだ計測はしていませんが・・・)

が、シンプルな今回のデモシーンでは特に重さは感じなかったので、コンテンツが重すぎない限りは問題ないかなとも思います。

TensorFlow Lite Unity Pluginを利用してDeepLabを動かすまで

概要

最近、ARでSemantic Segmentaionを試そうと色々やっているのでそのメモです。
最終的には拾ってきたDeepLabモデルを変換して実際に動かすまでをやってみようと思います。

今回の記事はこちらを参考にさせていただきました。

asus4.hatenablog.com

元々TensorFlow LiteにはUnity Pluginがあるようで、上記ブログではそれを利用してUnityのサンプルを作ってるみたいです。
サンプル自体もTensorFlowにあるものを移植したもののようです。

公開されているUnityのサンプルプロジェクトは以下です。
今回の記事もこれをベースに色々調べてみたものをまとめたものです。

github.com

今回試したものは、以下のような感じで動作します。

youtu.be



全体を概観する

まずはTensorFlowについて詳しくない人もいると思うのでそのあたりから書きたいと思います。

TensorFlow Liteとは

いきなりサンプルの話にはいる前に、ざっくりとTensorFlow Liteについて触れておきます。

TensorFlow Liteとは、TensorFlowをモバイルやIoTデバイスなど比較的非力なマシンで動作させることを目的に作られたものです。
つまり裏を返せば、通常のTensorFlowはモバイルで動作させるには重くて向いていないということでもあります。

詳細についてはTensorFlow Liteのサイトをご覧ください。

www.tensorflow.org

こちらの記事も参考になります。

note.com

TensorFlow Liteのモデル

TensorFlow Liteのモデルは.tfliteという拡張子で提供され、主に、TensorFlow本体で作られた(訓練された)データを変換することで得ることができます。サイトから引用すると以下のように説明されています。

To use a model with TensorFlow Lite, you must convert a full TensorFlow model into the TensorFlow Lite format—you cannot create or train a model using TensorFlow Lite. So you must start with a regular TensorFlow model, and then convert the model.

特にこの「you must convert a full TensorFlow model into the TensorFlow Lite format.」というところからも分かるように、TensorFlow LiteではTensorFlowのすべての機能を使えるわけではないということです。
逆に言えば、軽量化・最適化を施して機能を削ることでエッジデバイススマホやIoTデバイス)でも高速に動作するというわけなんですね。

DeepLabとは

今回試したのは『DeepLab』と呼ばれる、Googleが提案した高速に動作するセマンティックセグメンテーションのモデル(ネットワーク)です。
セマンティックセグメンテーションについては後日書く予定ですが、こちらのPDFがだいぶ詳しく解説してくれているのでそちらを見るとより理解が深まるかと思います。

DeepLabについてはGoogleのブログでも記載があります。

ai.googleblog.com

ものすごくざっくり言うと、セマンティックセグメンテーションとは、画像の中にある物体を認識しその物体ごとに色分けして分類する、という手法です。
手法自体はいくつか提案されており、今回使用したDeepLabもそうした手法のうちのひとつです。

特徴としてはモバイルでも動作するほど軽く、高速に動くという点です。
そのためARで利用することを考えた場合に、採用する最有力候補になります。(なので今回試した)

ちなみにGitHubにも情報が上がっています。

github.com

ひとまず、前提知識として知っておくことは以上となります。
次からは実際にUnityで動かす過程を通して、最終的には拾ってきたDeepLabモデルを変換して実際に動かすまでをやってみようと思います。

Unityサンプルを動かす

Unityのサンプルを動かしてみましょう。
まずはGitHubからダウンロードしてきたものをそのまま実行してみます。

youtu.be

動画を見てもらうと自転車と車に対して反応しているのが分かるかと思います。

サンプルプロジェクトに含まれるモデルはこちらで配布されているものです。

スターターモデルと書いてあるように、若干精度は低めかなという印象です。
本来的にはここから自分の目的にあったものにさらに訓練していくのだと思います。

サンプル以外のモデルを試す

サンプルに含まれているモデルが動くことが確認できました。
次は別のモデルを試してみることにします。

モデルを変換する

さて、別のモデルといってもイチから訓練するものではなく、すでに訓練され公開されているものを探してきてそれを変換したものを使ってみたいと思います。
これによってサンプルとの違いができ、より深く動作を確認できると思います。

ということで参考にさせてもらったのが以下のリポジトリです。

github.com

ここのREADMEの下の方にあるリンクからモデルをダウンロードしてきます。

f:id:edo_m18:20200807135925j:plain

TensorFlow Liteのモデルへの変換について

TensorFlowで訓練したモデルをLite版へと変換していきます。

ちなみにTensorFlowのモデルにはいくつかの形式があり、以下の形式が用いられます。

  • SavedModel
  • Frozen Model
  • Session Bundle
  • Tensorflow Hub モジュール

これらのモデルについては以下のブログを参照ください。

note.com

今回はこの中の『Frozen Model』で保存されたものを扱います。

ちなみにもし自分で訓練データを作成する場合は注意しなければならない点があります。

前節で書いたように、TensorFlow Liteはいくつかの機能を削減して軽量化を図るものです。
そのため、TensorFlowでは問題なく動いていたものが動かないことも少なくありません。

このあたりについてはまだ全然詳しくないのでここでは解説しません。(というかできません)

以下の記事が変換についてとても詳しく書かれているので興味がある方は参照してみてください。
ちなみに以下の記事は「量子化(Quantization)」を目的とした変換に焦点を絞っています。

qiita.com


以下、余談。

量子化(Quantization)とは

上記記事では以下のように説明されています。

上記のリポジトリから Freeze_Graph (.pb) ファイルをダウンロードします。 ココでの注意点は ASPP などの特殊処理が入ったモデルは軒並み量子化に失敗しますので、なるべくシンプルな構造のモデルに限定して取り寄せることぐらいです。

※ ... ちなみにASPP - Atrous Spatial Pyramid Poolingの略です。

ASPPがなにかを調べてみたら以下の記事を見つけました。

37ma5ras.blogspot.com

記事には以下のように記載されています。

  1. 画面いっぱいに写っていようと画面の片隅に写っていようと猫は猫であるように,image segmentationではscale invarianceを考慮しなければならない. 著者はAtrous Convolutionのdilationを様々に設定することでこれに対処している(fig.2). 著者はこの技法を”atrous spatial pyramid pooling”(ASPP)と呼んでいる.

さらに該当記事から画像を引用させてもらうと、

f:id:edo_m18:20200807124055p:plain

おそらくこの画像の下の様子がピラミッドに見えることからこう呼んでいるのだと思います。
が、量子化に際してこの技法が含まれていると変換できないようなのでここではあまり深堀りしません。

簡単に量子化について触れておくと、以下の記事にはこう説明されています。

情報理論における量子化とは、アナログな量を離散的な値で近似的に表現することを指しますが、本稿における量子化は厳密に言うとちょっと意味が違い、十分な(=32bitもしくは16bit)精度で表現されていた量を、ずっと少ないビット数で表現することを言います。

ニューラルネットワークでは、入力値とパラメータから出力を計算するわけですが、それらは通常、32bitもしくは16bit精度の浮動小数点(の配列)で表現されます。この値を4bitや5bit、もっと極端な例では1bitで表現するのが量子化です。1bitで表現する場合は二値化(binarization)という表現がよく使われますが、これも一種の量子化です。

量子化には、計算の高速化や省メモリ化などのメリットがあります。

developer.smartnews.com

要するに、本来はintfloatなど「大きな容量(16bit ~ 32bit)」を使うものをより小さい容量でも同じような精度を達成する方法、ということでしょうか。
これにはメモリ的なメリットや、単純に演算数の削減が見込めそうです。

モバイルではとにかくこのあたりの最適化は必須になるので量子化は必須と言ってもいいと思います。
(そういう意味で、量子化モデルの変換記事はめちゃめちゃ濃いので一度目を通しておくといいと思います)

が、後半で説明するtfliteへの変換でQuantizationのパラメータを設定するとうまく動作しなかったのでさらなる調査が必要そうです・・・。


閑話休題

モデルの入力・出力を調べる

ではモデルを変換していきましょう。
モデルを変換するためにはいくつかの情報を得なければなりません。

ここでは詳細は割愛しますが、ディープラーニングではニューラルネットワークへの入力出力(Input / Output)が大事になってきます。

ものすごくざっくり言えば、ディープラーニングはひとつの関数です。
プログラムでも、関数を利用したい場合はその引数と戻り値がなにかを知る必要があることに似ています

自分でネットワークを構築し訓練したものであればすでに知っている情報かもしれませんが、だいたいの場合はどこからか落としてきたモデルや、あるいは誰かの構築したネットワークを利用するというケースがほとんどでしょう。

そのため入力と出力を調べる必要があります。
調べ方については上のほうで紹介したこちらの記事が有用です。

そこで言及されていることを引用させていただくと、

INPUTは Input、 形状と型は Float32 [?, 256, 256, 3]、 OUTPUTは ArgMax、 形状と型は Float32 [?, 256, 256] のようです。 なお一見すると ExpandDims が最終OUTPUTとして適切ではないか、と思われるかもしれませんが、 実は Semantic Segmentation のモデルにほぼ共通することですが ArgMax を選定すれば問題ありません。

とあります。
ここで言及されている名前はネットワーク次第なので毎回これになるとは限りません。
しかし

実は Semantic Segmentation のモデルにほぼ共通することですが ArgMax を選定すれば問題ありません。

というのはとても大事な点なので覚えておきましょう。

ちなみにモデルの中身を可視化するツールがあります。
以下の『Netron』というものです。

Webサービスを利用して見てもいいですし、アプリもあるのでよく使う場合はインストールしておいてもいいでしょう。

lutzroeder.github.io

では先ほど落としてきたモデルを読み込ませて見てみましょう。

f:id:edo_m18:20200807151419j:plain

入力はInputという名前で、入力の型はfloat32、入力のサイズは256 x 2563チャンネルというのが分かります。
とても簡単に可視化できるのでとてもオススメのツールです。

続けて出力も見てみましょう。

f:id:edo_m18:20200807151733j:plain

参考にした記事に習ってArgMaxの部分を見てみます。
確認すると出力はArgMaxという名前でint32型、256 x 256の出力になるようです。

モデルの変換には「tflite_convert 」コマンドを使う

情報が揃ったのでダウンロードしてきたモデルを変換します。

変換にはtflite_convertコマンドを利用します。

今回の変換では以下のように引数を指定しました。

tflite_convert ^
  --output_file=converted_frozen_graph.tflite ^
  --graph_def_file=frozen_inference_graph.pb ^
  --input_arrays=Input^
  --output_arrays=ArgMax ^
  --input_shapes=1,256,256,3 ^
  --inference_type=FLOAT ^
  --mean_values=128 ^
  --std_dev_values=128

引数に指定している--input_arrays--output_arraysが、先ほど調べた名前になっているのが分かります。
さらに--input_shapesには入力の形として先ほど調べた256 x 256を指定しています。
これを指定することで適切に変換することができます。

では変換されたモデルをNetronで可視化してみましょう。

f:id:edo_m18:20200807185434p:plain

内容が変化しているのが分かります。
これをUnityのサンプルプロジェクトに入れて利用してみます。

f:id:edo_m18:20200807165038j:plain

これであとは実行するだけ・・・にはいきません。

drive.google.com ※ 変換したモデルを念の為公開しておきます。

モデルに合わせてC#を編集する

実はサンプルで用意されているDeepLabスクリプトはサンプルに含まれているモデルに合わせて実装されているため、今回のケースの場合は少し修正をしなければなりません。

なのでNetronによって確認できる型や形状に定義を変更します。
ということでDeepLabクラスを修正します。

今回修正した内容のdiffは以下です。

f:id:edo_m18:20200807195549p:plain

さて、さっそく変換したモデルを読み込ませて使ってみましょう。
・・・と勢い込んでビルドしてみるものの動かず。Logcatで見てみると以下のようなエラーが表示されていました。

08-07 10:50:49.160: E/Unity(6127): Unable to find libc
08-07 10:50:49.163: E/Unity(6127): Following operations are not supported by GPU delegate:
08-07 10:50:49.163: E/Unity(6127): ARG_MAX: Operation is not supported.
08-07 10:50:49.163: E/Unity(6127): BATCH_TO_SPACE_ND: Operation is not supported.
08-07 10:50:49.163: E/Unity(6127): SPACE_TO_BATCH_ND: Operation is not supported.
08-07 10:50:49.163: E/Unity(6127): 53 operations will run on the GPU, and the remaining 31 operations will run on the CPU.
08-07 10:50:49.163: E/Unity(6127): TensorFlowLite.Interpreter:TfLiteInterpreterCreate(IntPtr, IntPtr)
08-07 10:50:49.163: E/Unity(6127): TensorFlowLite.Interpreter:.ctor(Byte[], InterpreterOptions)
08-07 10:50:49.163: E/Unity(6127): TensorFlowLite.BaseImagePredictor`1:.ctor(String, Boolean)
08-07 10:50:49.163: E/Unity(6127): TensorFlowLite.DeepLab:.ctor(String, ComputeShader)
08-07 10:50:49.163: E/Unity(6127): DeepLabSample:Start()

いくつかのオペレーションがGPUに対応していないためのエラーのようです。
ということで、以下の部分をfalseにして(GPU未使用にして)ビルドし直してみます。
ちなみにiOSではGPUオンの状態でも問題なく動いたのでAndroid版の問題のようです

public DeepLab(string modelPath, ComputeShader compute) : base(modelPath, true)
{
    // ... 後略

これを、

public DeepLab(string modelPath, ComputeShader compute) : base(modelPath, false)
{
    // ... 後略

こうします。
この第二引数がGPUを使うかどうかのフラグの指定になっています。

あとはこれをビルドし直して動かすだけです。

youtu.be

(すでに冒頭でも載せていますが)これで無事に動きました!

最後に

これがゴールではなくむしろスタート地点です。
ここから、独自のモデルの訓練をして目的に適合する結果を得られるように調整していかなければなりません。

が、ひとまずはTensorFlowで作られたモデルをtflite形式に変換して動作させるところまで確認できたので、あとはこの上に追加して作業をしていく形になります。

ここまで来るのは長かった・・・。
やっと本題に入れそうです。

Oculus Questでハンドトラッキングを使ってみる

概要

Oculus Questのハンドトラッキングが利用できるようになったので使ってみたいと思います。
そこで、実際に使用するにあたってセットアップ方法とどういう情報が取れるのか、どう使えるかなどをまとめておきます。

ドキュメントは以下です。

developer.oculus.com

ちなみにドキュメントにも注意書きが書かれていますが、Oculus Linkでのハンドトラッキングは開発用でのみ動作するようです。

注:Oculus QuestとOculus Linkを使用する場合、PC上でのハンドトラッキングの使用はUnityエディターでサポートされています。この機能は、Oculus Quest開発者の反復時間短縮のため、Unityエディターでのみサポートされています。


Table of Contents


セットアップ

まずはプロジェクトをセットアップしていきます。

ここはHand Gesture用ではなく普通のOculusのセットアップです。
なのですでに知っている方は読み飛ばしてもらって大丈夫です。

XR Managementのインストール

Oculusを利用する場合は、Package ManagerからXR Managementをインストールします。

f:id:edo_m18:20200626113110p:plain

インストールすると、Project SettingsにXR Managementの項目が追加されるので、そこからOculus用のPluginをインストールします。

f:id:edo_m18:20200626113336p:plain

こちらもインストールが終わると画面が以下のように変化するので、Plugin ProvidersOculus Loaderを追加します。

f:id:edo_m18:20200626113628p:plain

※ ちなみに、Unityエディタ上でOculus Linkを使って開発を行う場合はPC向けにもOculus Loaderを設定する必要があります。

Oculus Integrationをインポート

次に、Unity Asset StoreからOculus Integrationをインポートします。

インポートが終わったらOVRCameraRigをシーン内に配置します。
その際、シーン内のカメラと重複するので元々あったほうを削除します。

f:id:edo_m18:20200626115826p:plain

配置したらOVRCameraRigにアタッチされているOVR ManagerHand Tracking SupportのリストからControllers and Handsを選択します。

f:id:edo_m18:20200626120427p:plain

※ ちなみにPlatformの設定がAndroidなっていないと設定できないようなので注意してください。

ハンドトラッキングを有効にする

ハンドトラッキングの機能を利用するためにはOculus Quest本体側の設定も必要になります。
ドキュメントから引用すると以下のように設定します。

ユーザーが仮想環境で手を使用するには、Oculus Quest上でハンドトラッキング機能を有効にする必要があります。

Oculus Questで、[Settings(設定)] > [Device(デバイス)]に移動します。 トグルボタンをスライドすることによってハンドトラッキング機能を有効にします。 手とコントローラーの使用を自動で切り替えられるようにするには、トグルボタンをスライドすることによって手またはコントローラーの自動有効化機能を有効にします。

シーンへの手の追加

ドキュメントによると、

手を入力デバイスとして使用するには、手動でシーンに追加する必要があります。手はOVRHandPrefab prefaにより実装されています。

とのことなので、対象のPrefabを配置します。
対象のPrefabはOculus/VR/Prefabsにあります。

それを、シーンに配置したOVRCameraRig以下のLeftHandAnchorRightHandAnchorの下に配置します。

f:id:edo_m18:20200626121003p:plain

手のタイプを設定する

配置したOVRHandPrefabは両手用になっているので、以下の3つのコンポーネントの設定を適切な手のタイプ(左手 or 右手)に変更します。

  • OVRHand
  • OVRSkeleton
  • OVRMesh

f:id:edo_m18:20200626121427p:plain

以上でセットアップは終了です。
あとはOculus Linkでつないで再生ボタンを押すと以下のように手が表示されるようになります。

※ バージョンによってはこれで正常に動作しない場合があるかもしれません。その場合はOculus Integrationを最新にしてみてください。

ハンドトラッキングによるデータの取得

セットアップが終わったので、あとはハンドトラッキングシステムから得られるデータを用いて様々なコンテンツを作っていくことができます。
ここではいくつかのデータの取得方法をまとめておこうと思います。

各指のピンチ強度を測る

OVRHandには各指のピンチ強度(*)を測るAPIがあるので簡単に測ることができます。

  • ... ピンチ強度は各指が『親指とどれくらい近づいているかを測る値』です。曲がり具合ではないので注意です。
float thumbStr = _rightHand.GetFingerPinchStrength(OVRHand.HandFinger.Thumb);
float indexStr = _rightHand.GetFingerPinchStrength(OVRHand.HandFinger.Index);
float middleStr = _rightHand.GetFingerPinchStrength(OVRHand.HandFinger.Middle);
float ringStr = _rightHand.GetFingerPinchStrength(OVRHand.HandFinger.Ring);
float pinkyStr = _rightHand.GetFingerPinchStrength(OVRHand.HandFinger.Pinky);

このあたりのデータを組み合わせれば、簡単なジェスチャーなどは検知できそうですね。

ボーン情報を取得する

ボーンの各情報を得るためにはOVRSkeletonクラスを利用します。
OVRSkeletonにはOVRBone構造体を保持するリストがあり、そこから情報を取り出します。

リストのどこにどのボーン情報が入っているかはOVRSkeleton.BoneId enumによって定義されており、それをintに変換して利用します。

なお、どこにどの情報が入っているかはドキュメントに記載されています。引用すると以下のように定義されています。

Invalid          = -1
Hand_Start       = 0
Hand_WristRoot   = Hand_Start + 0 // root frame of the hand, where the wrist is located
Hand_ForearmStub = Hand_Start + 1 // frame for user's forearm
Hand_Thumb0      = Hand_Start + 2 // thumb trapezium bone
Hand_Thumb1      = Hand_Start + 3 // thumb metacarpal bone
Hand_Thumb2      = Hand_Start + 4 // thumb proximal phalange bone
Hand_Thumb3      = Hand_Start + 5 // thumb distal phalange bone
Hand_Index1      = Hand_Start + 6 // index proximal phalange bone
Hand_Index2      = Hand_Start + 7 // index intermediate phalange bone
Hand_Index3      = Hand_Start + 8 // index distal phalange bone
Hand_Middle1     = Hand_Start + 9 // middle proximal phalange bone
Hand_Middle2     = Hand_Start + 10 // middle intermediate phalange bone
Hand_Middle3     = Hand_Start + 11 // middle distal phalange bone
Hand_Ring1       = Hand_Start + 12 // ring proximal phalange bone
Hand_Ring2       = Hand_Start + 13 // ring intermediate phalange bone
Hand_Ring3       = Hand_Start + 14 // ring distal phalange bone
Hand_Pinky0      = Hand_Start + 15 // pinky metacarpal bone
Hand_Pinky1      = Hand_Start + 16 // pinky proximal phalange bone
Hand_Pinky2      = Hand_Start + 17 // pinky intermediate phalange bone
Hand_Pinky3      = Hand_Start + 18 // pinky distal phalange bone
Hand_MaxSkinnable= Hand_Start + 19
// Bone tips are position only. They are not used for skinning but are useful for hit-testing.
// NOTE: Hand_ThumbTip == Hand_MaxSkinnable since the extended tips need to be contiguous
Hand_ThumbTip    = Hand_Start + Hand_MaxSkinnable + 0 // tip of the thumb
Hand_IndexTip    = Hand_Start + Hand_MaxSkinnable + 1 // tip of the index finger
Hand_MiddleTip   = Hand_Start + Hand_MaxSkinnable + 2 // tip of the middle finger
Hand_RingTip     = Hand_Start + Hand_MaxSkinnable + 3 // tip of the ring finger
Hand_PinkyTip    = Hand_Start + Hand_MaxSkinnable + 4 // tip of the pinky
Hand_End         = Hand_Start + Hand_MaxSkinnable + 5
Max              = Hand_End + 0

なお、以下の記事から画像を引用させていただくと、各IDの割り振りはこんな感じになるようです。

Finger's ID

qiita.com

簡単なハンドコントロール

OVRHandは、ホーム画面などで利用される手と同じ機能を簡単に利用するためのAPIを提供してくれています。
その情報にアクセスするにはOVRHand.PointerPoseプロパティを利用します。

これはTransform型で、手をポインタとして見た場合の位置と回転を提供してくれます。

視覚化してみたのが以下の動画です。

youtu.be

個人的にはやや直感に反する挙動だなと思っています。
手の指の向きは参考にされていないようで、基本的にポインタ方向は『手の高さ』によって算出されているような印象を受けます。

手のひらの法線を計算する

今回、個人プロジェクトで『手のひらの法線』が必要になり、それを求めるプログラムを書いたので参考までに載せておきます。

using System.Collections;
using System.Collections.Generic;

using UnityEngine;

namespace Conekton.ARUtility.Input.Application
{
    public class HandPoseController : MonoBehaviour
    {
        [SerializeField] private OVRHand _targetHand = null;
        [SerializeField] private OVRSkeleton _rightSkeleton = null;
        [SerializeField] private Transform _palmNormalTrans = null;
        [SerializeField] private float _detectLimit = 0.5f;

        private OVRSkeleton.BoneId[] _forPalmCalcTargetList = new[]
        {
            // First two of them are used for calculating palm normal.
            OVRSkeleton.BoneId.Hand_Index1,
            OVRSkeleton.BoneId.Hand_Pinky0,

            OVRSkeleton.BoneId.Hand_Middle1,
            OVRSkeleton.BoneId.Hand_Ring1,
            OVRSkeleton.BoneId.Hand_Pinky1,
            OVRSkeleton.BoneId.Hand_Thumb0,
        };

        private OVRSkeleton.BoneId BoneIDForNormalCalculation1 => OVRSkeleton.BoneId.Hand_Index1;
        private OVRSkeleton.BoneId BoneIDForNormalCalculation2 => OVRSkeleton.BoneId.Hand_Pinky0;

        public bool TryGetPositionAndNormal(out Vector3 position, out Vector3 normal)
        {
            if (!_targetHand.IsTracked)
            {
                position = Vector3.zero;
                normal = Vector3.up;
                return false;
            }

            Vector3 center = Vector3.zero;

            foreach (var id in _forPalmCalcTargetList)
            {
                OVRBone bone = GetBoneById(id);
                center += bone.Transform.position;
            }

            center /= _forPalmCalcTargetList.Length;

            position = center;

            OVRBone bone0 = GetBoneById(BoneIDForNormalCalculation1);
            OVRBone bone1 = GetBoneById(BoneIDForNormalCalculation2);

            Vector3 edge0 = bone0.Transform.position - center;
            Vector3 edge1 = bone1.Transform.position - center;

            normal = Vector3.Cross(edge0, edge1).normalized;

            return true;
        }

        private OVRBone GetBoneById(OVRSkeleton.BoneId id)
        {
            return _rightSkeleton.Bones[(int)id];
        }
    }
}

考え方はシンプルで、ハンドトラッキングから得られるボーンの位置をいくつか選び、それらの平均の位置を手のひらの位置としています。
また法線については、求めた手のひらの位置とふたつのボーンの位置との差分ベクトルを取り、それの外積を取ることで求めています。

まとめ

Oculus Questのハンドトラッキングの精度は驚異と言っていいと思います。

過去に、Leapmotionを使ったコンテンツを開発したことがありますが、Leapmotionよりも精度が高い印象です。(それ用デバイスより精度が高いって・・)
実際に体験すると、本当にVR内に自分の手があるかのように思えるくらいなめらかにトラッキングしてくれます。

ハンドトラッキングを用いたUI/UXはさらに発展していくと思うので、これからとても楽しみです。

Unityの推論エンジン『Barracuda』を試してみたのでそのメモ

概要

以下の記事を参考に、最近リリースされたUnity製推論エンジンを試してみたのでそのメモです。

qiita.com

note.com

Barracudaとは?

Barracudaとは、ドキュメントにはこう記載されています。

Barracuda is lightweight and cross-platform Neural Net inference library. Barracuda supports inference both on GPU and CPU.

軽量でクロスプラットフォームニューラルネットワークの推論ライブラリということですね。
そしてこちらのブログによるとUnity製のオリジナルだそうです。

セットアップ

BarracudaPackage Managerから簡単にインストールできます。
インストールするにはWindow > Package ManagerからBarracudaを選択してインストールするだけです。

f:id:edo_m18:20200706131221p:plain

モデルの準備

これでC#からBarracudaを利用する準備は終わりです。
しかし、Deep Learningを利用して処理を行うためには訓練済みのモデルデータが必要です。
これがなければDeep Learningはなにも仕事をしてくれません。

Barracudaで扱えるモデル

Deep Learningを利用するためのフレームワークは多数出ています。有名どころで言えばTensorFlowなどですね。
そしてこうしたフレームワークごとにフォーマットがあり、そのフレームワークで訓練したデータは独自の形式で保存されます。
つまり言い換えればフレームワークごとにデータフォーマットが異なるということです。

そしてBarracudaでもそれ用のデータ・フォーマットで保存されたモデルデータが必要になります。
しかしBarracudaではONNX形式のモデルデータも扱うことができるようになっています。

ONNXとは?

ONNXはOpen Neural Network eXchangeの略です。(ちなみに『オニキス』と読むらしいです)
Openの名がつく通り、Deep Learningのモデルを様々なフレームワーク間で交換するためのフォーマット、ということのようです。

前述の通りBarracudaでも利用できます。

ONNXモデルの配布サイト

以下のリポジトリからいくつかの学習済みモデルがDownloadできます。

github.com

モデルの変換

Barracudaでは主要なフレームワークのモデルからBarracuda形式およびONNX形式に変換するためのツールを提供してくれています。 

note.com

詳細は上記の記事を見てもらいたいと思いますが、どういう感じで変換を行うのかのコマンド例を載せておきます。

$ python tensorflow_to_barracuda.py ../mobilenet_v2_1.4_224_frozen.pb ../mobilenet_v2.nn

Converting ../mobilenet_v2_1.4_224_frozen.pb to ../mobilenet_v2.nn
Sorting model, may take a while...... Done!
IN: 'input': [-1, -1, -1, -1] => 'MobilenetV2/Conv/BatchNorm/FusedBatchNorm'
OUT: 'MobilenetV2/Predictions/Reshape_1'
DONE: wrote ../mobilenet_v2.nn file.

モデルを利用して推論する(Style Chnage)

冒頭で紹介したこちらの記事からStyle Changeの方法を見ていきます。
(これがおそらく一番短くて分かりやすい例だと思います)

using UnityEngine;
using Unity.Barracuda;

public class StyleChange : MonoBehaviour
{
    [SerializeField] private NNModel _modelAsset = null;
    [SerializeField] private RenderTexture _inputTexture = null;
    [SerializeField] private RenderTexture _outputTeture = null;

    private Model _runtimeModel = null;
    private IWorker _worker = null;

    private void Start()
    {
        // Load an ONNX model.
        _runtimeModel = ModelLoader.Load(_modelAsset);

        // Create a worker.
        // WorkerFactory.Type means which CPU or GPU prefer to use.
        _worker = WorkerFactory.CreateWorker(WorkerFactory.Type.Compute, _runtimeModel);
    }

    private void Update()
    {
        Tensor input = new Tensor(_inputTexture);
        Inference(input);
        input.Dispose();
    }

    private void Inference(Tensor input)
    {
        _worker.Execute(input);
        Tensor output = _worker.PeekOutput();
        output.ToRenderTexture(_outputTeture, 0, 0, 1/255f, 0, null);
        output.Dispose();
    }

    private void OnDestroy()
    {
        _worker?.Dispose();
    }
}

だいぶ短いコードですね。これで映像を変換できるのだから驚きです。
ただし、複雑なネットワークを使っている場合はその分処理が重くなるので、リアルタイムなポストエフェクトとしては利用できないと思います。

このコードを実行すると以下のような結果が得られます。
(もちろん、設定するモデルによって出力の絵は変わります)

f:id:edo_m18:20200724101348p:plain

InputにRenderTextureを与え、OutputもRenderTextureで受け取っていますね。
入出力ともに画像なので利用イメージがしやすいと思います。

しかし(書いておいてなんですが)こうしたスタイルの変更というのはゲームではあまり使用されないかもしれません。
それよりも、物体検知やセグメンテーションなどでその真価を発揮するのではないかなと思っています。

ということで、次は物体検知についても書いておきます。

モデルを利用して推論する(物体検知)

以下のコードはこちらのGitHubのものを参考にさせていただいています。
ファイルへの直リンク

using System;
using Barracuda;
using System.Linq;
using UnityEngine;
using System.Collections;
using System.Collections.Generic;
using System.Text.RegularExpressions;

public class Classifier : MonoBehaviour
{
    public NNModel modelFile;
    public TextAsset labelsFile;

    public const int IMAGE_SIZE = 224;
    private const int IMAGE_MEAN = 127;
    private const float IMAGE_STD = 127.5f;
    private const string INPUT_NAME = "input";
    private const string OUTPUT_NAME = "MobilenetV2/Predictions/Reshape_1";

    private IWorker worker;
    private string[] labels;


    public void Start()
    {
        this.labels = Regex.Split(this.labelsFile.text, "\n|\r|\r\n")
            .Where(s => !String.IsNullOrEmpty(s)).ToArray();
        var model = ModelLoader.Load(this.modelFile);
        this.worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, model);
    }


    private int i = 0;
    public IEnumerator Classify(Color32[] picture, System.Action<List<KeyValuePair<string, float>>> callback)
    {
        var map = new List<KeyValuePair<string, float>>();

        using (var tensor = TransformInput(picture, IMAGE_SIZE, IMAGE_SIZE))
        {
            var inputs = new Dictionary<string, Tensor>();
            inputs.Add(INPUT_NAME, tensor);  
            var enumerator = this.worker.ExecuteAsync(inputs);

            while (enumerator.MoveNext())
            {
                i++;
                if (i >= 20)
                {
                    i = 0;
                    yield return null;
                }
            };

            // this.worker.Execute(inputs);
            // Execute() scheduled async job on GPU, waiting till completion
            // yield return new WaitForSeconds(0.5f);

            var output = worker.PeekOutput(OUTPUT_NAME);

            for (int i = 0; i < labels.Length; i++)
            {
                map.Add(new KeyValuePair<string, float>(labels[i], output[i] * 100));
            }
        }

        callback(map.OrderByDescending(x => x.Value).ToList());
    }


    public static Tensor TransformInput(Color32[] pic, int width, int height)
    {
        float[] floatValues = new float[width * height * 3];

        for (int i = 0; i < pic.Length; ++i)
        {
            var color = pic[i];

            floatValues[i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD;
            floatValues[i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD;
        }

        return new Tensor(1, height, width, 3, floatValues);
    }
}

ちなみにStartメソッドで読み込んでいるテキストは分類の名称が改行区切りで入っているただのテキストファイルです。
以下はその一部。(ファイルへの直リンク

background
tench, Tinca tinca
goldfish, Carassius auratus
great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
tiger shark, Galeocerdo cuvieri
hammerhead, hammerhead shark
electric ray, crampfish, numbfish, torpedo
stingray
cock
hen
...

入力画像から識別したラベルに対する確率を取得する

これらの動作の基本的な流れは以下です。

  1. 画像をテンソル化して入力データとする
  2. ニューラルネットワークを通して出力を得る
  3. 出力はラベル数と同じ数の1階テンソルで、値はそれぞれの分類の確度(%)
  4. あとは確度に応じて望む処理をする

という具合です。
そしてその出力部分を抜粋すると以下のようになっています。

var output = worker.PeekOutput(OUTPUT_NAME);

for (int i = 0; i < labels.Length; i++)
{
    map.Add(new KeyValuePair<string, float>(labels[i], output[i] * 100));
}

output[i] * 100としている箇所が確率を%に変換している部分ですね。
つまり、該当番号(i)のラベルである確率をDictionary型の変数に入れて返しているというわけです。

取得するテンソルvar output = worker.PeekOutput(OUTPUT_NAME);でアクセスしています。
そしてこのOUTPUT_NAMEprivate const string OUTPUT_NAME = "MobilenetV2/Predictions/Reshape_1";と定義されています。

この文字列は訓練されたモデルのネットワークの変数名でしょう。
つまり推論の結果をこれで取得している、というわけですね。

モデルのInput / Outputを確認する

前述したように、モデルへのInputとOutputを明示的に指定する必要があり、そのための文字列を知る必要があります。

そのための情報はインスペクタから確認することができます。
インポートしたモデルファイルを選択すると以下のような情報が表示されます。

f:id:edo_m18:20200724100116p:plain

この図のInputs (1)Outputs (1)がそれに当たります。
その下のLayersニューラルネットワークのレイヤーの情報です。
つまりどんなネットワークなのか、ということがここで見れるわけですね。

画像として出力されるモデルかどうか確認する

ちなみにStyle Changeのところで書いたように、画像自体を変換して出力するネットワークなのかどうかは以下のInputs / Outputsを確認すると分かると思います。

f:id:edo_m18:20200724101710p:plain

入力が画像サイズで、出力も画像サイズになっている場合はそれが画像として出力されていると見ることが出来ます。
出力が[1, 224, 224, 3]となっているのは224 x 224サイズ3チャンネル(RGB)1つ出力することを意味しているわけですね。

まとめ

Deep Learningニューラルネットワーク)は、極論で言えば巨大な関数であると言えると思います。
入力となる引数も巨大であり、そこから目的の出力を得ること、というわけですね。


y = f(x)

の、 xが入力(つまり今回の場合は画像)で、その結果(どのラベルの確率が高いか)が y、というわけです。

これを上記の物体検知に当てはめてC#で書き直すと、

// inferenceは「推論」を意味する英単語
Dictionaly<string, float> map = Inference(_inputTexture);

という感じですね。

最近はDeep Learningについてずっと調査をしています。

そこでの自分の大まかな理解を書いておくと、この巨大な関数を解くためのパラメータを『機械学習』で調整させる。
その調整する方法が『パーセプトロン』をベースとする『ニューラルネットワーク』を利用して行っている。

そしてこの調整されたパラメータとネットワーク構成がつまりは訓練済みモデル(データ)というわけです。
なのでそのパラメータを利用した関数を通すとなにかしらの意味がある出力が得られる、というわけなんですね。

もちろん、そのネットワークをどう組むか。それがどう実現しているのか。基礎を学ぼうとすると膨大な知識が必要になります。

しかし、こと利用する視点だけで見ればなんのことはない、ただの関数実行だ、と見ると利用がしやすいのではないかと思います。

これを機に色々とUnity上でディープラーニングを使って色々とやっていきたいと思います。

ComputeShaderで動かす様々な形に変化するパーティクルシステム

概要

今回は1/7〜1/10にかけて開催されたCES2020でMESONが展示した『PORTAL』の中で、メインの演出となったパーティクルシステムについて書きたいと思います。

このパーティクルシステムの特徴は、指定したモデルグループの頂点位置にパーティクルをまとわせるというものです。パーティクルの色は対象モデルの頂点位置と同じ色になるように調整されています。

以下の動画で、どんな動きをするかを見ることができます。

youtu.be

今回の記事の内容はGitHubにもアップしてあるので実際の動作を見たい方はそちらもご覧ください。

github.com

サンプルプロジェクトを実行すると以下のようなシーンを見ることができます。
(※ 動画内のモデルはUnity Asset Storeのものなのでリポジトリ内には含まれません。モデルは各自でご用意ください


Transform particle system demo

常にシーンに存在しているパーティクルが設定に応じて様々な形に変化します。
メッシュの頂点に張り付いたり、任意の形状になったり、テクスチャの絵の形になったり。

今回はこれを達成するために工夫した点についてまとめます。


Table of Contents


処理フロー

まずはざっくり概観を。

  1. 対象となるモデル(メッシュ)の情報を集める
  2. それをグルーピングする(すべての情報を配列にまとめる)
  3. 初期化用データのためのインデックスバッファを用意する
  4. データをComputeShaderに送る
  5. 必要に応じてComputeShaderでターゲット位置を更新
  6. GPUインスタンシングでまとめて描画する

方針

変化しないデータをひとつにまとめる

最初、モデル(メッシュ)のデータをターゲットごとに都度取り出して処理しようとしたところ単純に負荷が高すぎてうまくいきませんでした。
特に、プロジェクトでは1モデルあたり10,000〜20,000頂点くらいあり、さらにそれを5、6体分グルーピングする必要があったためそれなりの負荷になってしまいました。

ゲーミングPCのような高性能なPCでさえもプチフリーズをするくらいには負荷が高かったです。

そこでメモリアクセスを効率よくするため、またコピーを最適化するため頂点情報をまとめてひとつの配列にしSystem.Array.Copyを利用することで対応しました。(同様にUVなどの値も別途配列にまとめる)

つまり、メモリ効率を意識してすべての情報を配列にまとめるということをしました。
テクスチャもTexture2DArrayという仕組みで配列にまとめてシェーダに送ります。

頂点データも変化しないデータとしてまとめてしまっているのは対象モデルがSkinnedMeshRendererを持っておらずアニメーションしない前提のためです。

アニメーションする場合は毎フレーム頂点位置を更新してやる必要があるため負荷が高いかもしれません。(時間があったらSkinnedMeshRendererにも対応してみたいと思います)

初期化用データを用意してパーティクルのターゲット変更をCompute Shaderで行う

各パーティクルそれぞれは常になにかしらのターゲット位置を持っています。(ランタイムで位置を計算する場合もあります)
またターゲットは、冒頭の動画のように途中で切り替えることができるようになっています。

しかしながら、パーティクル数は膨大になることが多くそれをCPUで切り替えるには重すぎる可能性があります。
そこで本パーティクルシステムでは初期化(ターゲットの変更処理)もGPUで行っています。

イメージ的には、ターゲットが変わった際にグループ単位でごっそり内容を入れ替えてしまうイメージです。

Indexバッファを用いて初期化用データ配列にアクセスする

前述のターゲット変更の仕組みですが、ここにも少し工夫があります。
パーティクルの更新処理としてはスレッドIDを配列の添字にしてアクセスすることが多いと思いますが、今回のシステムではシーンに存在しているパーティクル数とターゲットの数が異なる場合がほとんどです。

つまりパーティクルの数と初期化用データの数が異なるということです。
当然、配列の数が異なるということは同一の添字を利用することはできません。

例えば冒頭の動画を見てもらうと分かるように、移動対象がメッシュのモデルの頂点だけではなく、テクスチャのピクセルなどにもなりえます。

そこで採用したのが『インデックスバッファ』です。

イメージ的には通常のレンダリングパイプラインで利用されるインデックスバッファと同様です。

各パーティクルを初期化する際、InitData構造体の配列を用いてパーティクルの情報を書き換えるようにしているのですが、どのInitData構造体にアクセスすればいいのかはスレッドIDからでは判断できません。
そこでインデックスバッファの出番、というわけです。

IndexBuffer配列はパーティクル数と同じになっていますが、格納されているのは対象となるInitData配列のインデックスです。
このインデックス番号を利用して初期化対象のデータを特定し、初期化を行うというフローになっています。

初期化時のカーネル関数を見てみるとイメージが掴めるかと思います。

[numthreads(THREAD_NUM, 1, 1)]
void SetupParticlesImmediately(uint id : SV_DispatchThreadID)
{
    TransformParticle p = _Particles[id];

    uint idx = _IndexBuffer[id];

    float4x4 mat = _MatrixData[_InitDataList[idx].targetId];

    p.isActive = _InitDataList[idx].isActive;
    p.position = mul(mat, float4(_InitDataList[idx].targetPosition, 1.0)).xyz;
    p.targetPosition = p.position;
    p.uv = _InitDataList[idx].uv;
    p.targetId = _InitDataList[idx].targetId;
    p.scale = _InitDataList[idx].scale;
    p.horizontal = _InitDataList[idx].horizontal;
    p.velocity = _InitDataList[idx].velocity;

    _Particles[id] = p;
}

上記のuint idx = _IndexBuffer[id];の部分がインデックスバッファからインデックスを取得している箇所ですね。

解説

全体を概観できたところで、コードを交えながらどう実装したかを解説していきたいと思います。

Compute Shaderの利用

今回のパーティクルシステムではCompute Shaderを利用しています。

主なCompute Shaderのカーネルは2種類

まずはCompute Shaderについてです。
そもそもCompute Shaderってなんぞ? って人は前に書いた記事も参考にしてみてください。

edom18.hateblo.jp

edom18.hateblo.jp

今回、パーティクルの計算のために用意したのは主に2種類のカーネルです。
後述しますが、それぞれに複数のカーネルがあります。

種類は以下。

  1. パーティクルをターゲットに動かすための更新用カーネル
  2. パーティクルのターゲット自体を変更するための初期化カーネル

の2つです。

構造体

カーネルの説明に入る前に、今回のパーティクルシステムで利用している構造体を整理しておきます。
利用している構造体は以下の2つ。

ひとつめがパーティクル用でふたつめが初期化用です。

public struct TransformParticle
{
    public int isActive;
    public int targetId;
    public Vector2 uv;

    public Vector3 targetPosition;

    public float speed;
    public Vector3 position;

    public int useTexture;
    public float scale;

    public Vector4 velocity;

    public Vector3 horizontal;
}

public struct InitData
{
    public int isActive;
    public Vector3 targetPosition;

    public int targetId;
    public float scale;

    public Vector4 velocity;

    public Vector2 uv;
    public Vector3 horizontal;
}

初期化用のほうは、初期化したい内容だけを定義した構造体になっています。
ターゲット位置とターゲットのID(モデル行列やテクスチャアクセスに利用)、UV値、スケール、そしてアクティブ状態の内容を保持します。

初期化用カーネル

初期化用カーネルでは前述の通り、すでにバインド済みのパーティクルバッファを初期化用構造体によって初期化します。
また前述のように、対象のInitData配列にアクセスするためのインデックスバッファも利用します。

なお、初期化用のカーネルはふたつあります。

ひとつは、ターゲットを更新し、そのターゲットにスムーズに遷移するための初期化(SetupParticles)。
そしてもうひとつは『即座に』パーティクルの状態を変更するための初期化(SetupParticlesImmediately)です。

ひとつめがメインの初期化処理ですが、表現によっては即座にパーティクルを指定の位置に移動させたい場合があります。
その場合に利用するのが2番目のカーネルというわけです。

具体的なコードは以下。

[numthreads(THREAD_NUM, 1, 1)]
void SetupParticles(uint id : SV_DispatchThreadID)
{
    TransformParticle p = _Particles[id];

    uint idx = _IndexBuffer[id];

    float4x4 mat = _MatrixData[_InitDataList[idx].targetId];

    p.isActive = _InitDataList[idx].isActive;
    p.targetPosition = mul(mat, float4(_InitDataList[idx].targetPosition, 1.0)).xyz;
    p.uv = _InitDataList[idx].uv;
    p.targetId = _InitDataList[idx].targetId;
    p.scale = _InitDataList[idx].scale;
    p.horizontal = _InitDataList[idx].horizontal;
    p.velocity = _InitDataList[idx].velocity;

    _Particles[id] = p;
}

[numthreads(THREAD_NUM, 1, 1)]
void SetupParticlesImmediately(uint id : SV_DispatchThreadID)
{
    TransformParticle p = _Particles[id];

    uint idx = _IndexBuffer[id];

    float4x4 mat = _MatrixData[_InitDataList[idx].targetId];

    p.isActive = _InitDataList[idx].isActive;
    p.position = mul(mat, float4(_InitDataList[idx].targetPosition, 1.0)).xyz;
    p.targetPosition = p.position;
    p.uv = _InitDataList[idx].uv;
    p.targetId = _InitDataList[idx].targetId;
    p.scale = _InitDataList[idx].scale;
    p.horizontal = _InitDataList[idx].horizontal;
    p.velocity = _InitDataList[idx].velocity;

    _Particles[id] = p;
}

実は違いは一箇所だけです。p.positionを設定するかしないかです。
位置更新用のカーネルでは、呼び出されるごとに位置を更新するわけですが、ターゲットが存在するパーティクルの場合は今の位置から徐々にp.targetPositionに移動する、という実装になっています。

つまりは、p.positionが即座に変更されることで、見た目上即時、位置が更新されたように見えている、というわけですね。

さて肝心の処理ですが、基本的にはInitData構造体によって渡されたデータをコピーしているだけです。

初期化処理の中でのポイントは_MatrixData部分です。
targetIdによって計算中のパーティクルがどのモデルに属しているのかを判別しており、該当のモデル行列を取り出してターゲット位置に掛けることで実際のモデルの頂点位置を計算しています。

ちなみに感のいい方なら気づいたかもしれませんが、テクスチャのピクセルに変形する場合はモデル行列は必要ありません。そのため、その場合にはただのMaterix4x4.identityを利用しています。

また同様に、レンダリング用シェーダでも対象テクスチャを判別するためにtargetIdが必要になるため、パーティクルのデータとして設定しています。
なお、なぜこれが必要になるのかはレンダリングの詳細のときに解説します。

更新用カーネル

これは特にむずかしいことはしていません。
コードを見てもらうと分かりますが、ターゲット位置と現在位置の差分距離をDeltaTimeによって徐々にターゲット位置に近づけるように計算しているだけです。
(ただ、パーティクルごとにアニメーションの差を生ませるために『速度』の項目がありますが、本質的にはターゲットに近づけていることには変わりありません)

[numthreads(THREAD_NUM, 1, 1)]
void UpdateAsTarget(uint id : SV_DispatchThreadID)
{
    TransformParticle p = _Particles[id];

    float3 delta = p.targetPosition - p.position;
    float3 pos = (delta + p.velocity.xyz * 0.2) * _DeltaTime * p.speed;

    const float k = 5.5;
    p.velocity.xyz -= k * p.velocity.xyz * _DeltaTime;

    p.position += pos;
    p.useTexture = 1;

    _Particles[id] = p;
}

特定の形状に移動させるカーネル

基本的には設定されたモデルの頂点にパーティクルが張り付く動作をしますが、パーティクルの移動先をランタイムに計算することで任意の形状に散らすこともできます。

サンプルの動画では拡散したりぐるぐる周ったりしているのが分かるかと思います。

該当のコードは以下の通り。

[numthreads(THREAD_NUM, 1, 1)]
void UpdateAsExplosion(uint id : SV_DispatchThreadID)
{
    TransformParticle p = _Particles[id];

    float3 pos = (p.velocity.xyz) * p.velocity.w * _DeltaTime;

    float s = sin(rand(id) + _Time) * 0.00003;
    p.velocity.xyz += s;
    float k = 2.0;
    p.velocity.xyz -= k * p.velocity.xyz * _DeltaTime;

    p.position += pos;

    p.useTexture = 0;

    _Particles[id] = p;
}

パーティクルシステムでは更新方法のカーネルを変えることでパーティクルの位置計算の仕方を変更しています。
つまり、カーネルを増やせば様々な形状の変化を与えることができるわけです。

なので例では四散するパーティクルの更新用カーネルを示しましたが、サンプルの動画ではぐるぐると周る演出もあります。
これは更新用カーネルを変えることによって実現しています。

この更新用カーネルを増やすことでさらに様々な表現を作ることが可能になっています。

C#コード

次にC#側のコードを見ていきましょう。

今回のシステムで使う主なクラスは以下の3つです。

  • TransformParticleSystem
  • ParticleTarget
  • ParticleTargetGroup

各クラスの概要は以下の通りです。

TransformParticleSystem

TransformParticleSystemが今回のコア部分です。このクラスでCompute Shaderへのデータ設定や各種データの取得、整理などを行っています。

ParticleTarget

ParticleTargetはターゲットとなるMeshの各種情報を取得するためのラッパークラスです。
各種情報とは、頂点やUV、テクスチャ情報などパーティクルシステムで利用する情報です。

この派生型のクラスとして、テクスチャのピクセル位置に移動させるようなクラスも用意されています。
ただ基本概念は一緒なので説明は割愛します。

ParticleTargetGroup

ParticleTargetGroupParticleTargetをグループ化するためのクラスです。
今回のプロジェクトの要件では複数のモデルをひとつにまとめて扱う必要があったためこのグループクラスを定義しています。

このグループクラスでは、子どもに持つParticleTargetから頂点などの情報を取得しそれらをひとつにまとめる処理を担当しています。
最終的にこのグループ単位でパーティクルシステムにデータが渡され利用されます。

C#側では主にバッファの設定とCompute Shaderへの計算命令の実行(Dispatch)を行います。

バッファ周りの仕組みや概要については以前の記事を参考にしてください。

edom18.hateblo.jp

データの設定

今回のポイントはデータの変更処理にあります。

特にマトリクスの変更とターゲット位置の変更がメインです。
ということでマトリクスの変更処理部分を見てみます。

private void UpdateMatrices(ParticleTargetGroup group)
{
    group.UpdateMatrices();

    System.Array.Copy(group.MatrixData, 0, _matrixData, 0, group.MatrixData.Length);

    _matrixBuffer.SetData(_matrixData);
}

// ParticleTargetGroup.UpdateMatrices処理
public void UpdateMatrices()
{
    for (int i = 0; i < _targets.Length; i++)
    {
        _matrixData[i] = _targets[i].WorldMatrix; // localToWorldMatrixを返すプロパティ
    }
}

グルーピングされているターゲットのlocalToWorldMatrixを配列に詰めているだけです。

続いてターゲット位置の変更。

public void UpdateInitData(InitData[] updateData)
{
    if (updateData.Length > _initDataList.Length)
    {
        Debug.LogError("Init data list size is not enough to use.");
    }

    int len = updateData.Length > _initDataList.Length ? _initDataList.Length : updateData.Length;

    System.Array.Copy(updateData, _initDataList, len);

    _initDataListBuffer.SetData(_initDataList);
}

このInitDataが更新用データ構造体で、パーティクルひとつの更新情報になります。
このInitData配列を、各グループクラスが起動時にまとめあげて保持しているわけです。

そして変更のタイミングに応じて、すでに生成済みのInitData配列を渡すことで高速にターゲットを変更しているわけです。
なのでここをランタイムで生成して渡すことで完全に任意の位置にパーティクルを飛ばすことも可能になっています。

そしてこのあとに初期化用カーネルを呼び出してパーティクルの更新を行っています。

呼び出し部分は以下。

private void UpdateAllBuffers(int kernelId)
{
    SetBuffer(kernelId, _propertyDef.ParticleBufferID, _particleBuffer);
    SetBuffer(kernelId, _propertyDef.InitDataListID, _initDataListBuffer);
    SetBuffer(kernelId, _propertyDef.MatrixDataID, _matrixBuffer);
    SetBuffer(kernelId, _propertyDef.IndexBufferID, _indexBuffer);
}

private void Dispatch(int kernelId)
{
    _computeShader.Dispatch(kernelId, _maxCount / THREAD_NUM, 1, 1);
}

それぞれ必要なバッファを対象カーネルにセットし、そのあとでカーネルを起動します。
呼び出すカーネルは前述した初期化用カーネルです。

これを呼び出すことでパーティクルデータのバッファの内容が書き換わり、以後の位置更新ループによって別のターゲットへパーティクルが近づいていくことになります。

パーティクルのレンダリングについて

本パーティクルシステムの大まかな仕組みは当初から変わっていません。
しかしモック段階では問題にならなかったものも、実際のコンテンツに組み込むことで問題が出てくることが多々あります。

今回も同じように問題が発生しました。

問題というのは、モック段階ではパーティクルの移動のみに焦点を当てて実装を行っていたのですが、いざ組み込みを開始すると、パーティクルの移動先の対象モデルと同じ色にしないとなりません。
(でないとモデル形状になるだけの単色のパーティクルの塊になってしまう)

当然、対象モデルはテクスチャを持っており、頂点ごとに色など持っていません。
つまり、対象頂点位置の色をテクスチャから引っ張ってくる必要があるわけです。

しかし、パーティクルに利用するシェーダは当然ひとつだけで、GPUに転送できるデータにも限りがあります。
CPUのように柔軟にテクスチャ情報を取り出すわけにはいきません。
なにより、転送できるテクスチャの数にも限りがある上に、複数のテクスチャを配列以外でアクセスするには問題が多すぎました。

そこで、最初の『方針』のところでも書いたように、対象モデルのテクスチャをTexture2DArrayにまとめることで配列にし、かつ『どのテクスチャのどこをフェッチするか』という情報をパーティクル自体に持たせる必要があります。

ということを念頭に置いて、どうしているかのコード断片を抜き出すと以下のようになります。

// 頂点シェーダ
TransformParticle p = _Particles[id];

v2f o;

// 中略

o.uv.xy = p.uv.xy;
o.uv.z = p.targetId;

// フラグメントシェーダ
col = UNITY_SAMPLE_TEX2DARRAY(_Textures, i.uv);

Texture2DArrayでは最初の2要素(x, y)が通常のUV値として利用され、3要素目(z)が配列の添字として利用されます。
なのでそれを頂点シェーダからフラグメントシェーダへ転送し、UNITY_SAMPLE_TEX2DARRAYマクロを利用して対象の色をフェッチしている、というわけです。

正直、これが想定通りうまく行ったときは内心飛び跳ねていました。
移動自体は問題なく動いたものの、色味をつけるところで躓いていたので。

この仕組み自体はWindows, Mac OSX, iOS, Android, MagicLeapすべてで問題なく動いているので汎用的に使える仕組みだと思います。

なお、パーティクルを画面に表示するフローについては以前書いた以下の記事を参照ください。このパーティクルシステムを実装したあとに備忘録として書いた記事なのでほぼ内容はそのままです。

edom18.hateblo.jp

プロジェクトでのハマりポイント

最初のモック段階では問題なく動いていたものの、実際のプロジェクトでは画面になにも表示されないという問題が発生しました。

おそらく理由はメモリ不足。
ターゲットとなるモデルを複数指定していたのですが、同じモデルなのにも関わらず複数のグループに分割したために必要となるデータの重複が発生し、特にTexture2DArrayにまとめたテクスチャのデータ量が増えすぎたためだと思います。

なので参考にアップしてあるGitHubのコードではグループ化の最適化を行っています。
具体的には、サブグループという概念を取り入れてグループデータはひとつにまとめ、サブグループ単位でパーティクル位置を指定できるようにしました。

ただ今回の解説からは脱線するので説明は割愛します。

逆引きUniRx - 使用例から見る使い方

概要

UniRx。最近よく目にする気がしますが基本的な概念は『Reactive Extensions』です。
似た用語として『関数型リアクティブプログラミング:Functional Reactive Programming(FRP』がありますが概念としては違うようです。

ざっくりとRx系ライブラリを説明すると、連続した値を『ストリーム』と捉え、それをどう扱うかに焦点を当てたものです。
ストリームのイメージは、『なにか』がパイプの中を流れていくイメージです。
この『なにか』はデータであるかもしれないし、時間かもしれません。とにかく抽象化されたものがパイプ内を流れ、それに加工を加えながら処理するもの、という感じです。

今回はUniRxの使い方の説明は割愛します。使い方や基礎的なところは@toRisouPさんがとても詳しく記事を書いてくださっているのでそちらを見るといいでしょう。

qiita.com



今回書くのは、ある程度基礎が分かっていて概念も把握したものの『で、実務でどう使ったらいいんだ?』となった人向けに、具体的な事例を交えて逆引き的に利用できるようにまとめたものです。
(というか、完全に自分のためのメモです・・・w なので、今後使いやすい・思い出しやすいように逆引きで書いてるというわけです)

なので多分、随時更新されていくと思います。

ドラッグ処理

最近実装したものでドラッグ処理です。
例で使われている_raycasterはレイが当たっているかどうかを判定するためのクラスで、_inputControllerはコントローラのトリガー状態を提供してくれるものです。

例がだいぶ偏ったものになっていますが、VRでコントローラからレイを飛ばしてなにかをドラッグして加速度を計算、対象オブジェクトを移動させる、みたいなシーンを思い浮かべてください。

var startStream = this.UpdateAsObservable()
                      .Where(_ => _raycaster.IsHit && _inputController.IsTriggerDown);

var stopStream = this.UpdateAsObservable()
                     .Where(_ => !_raycaster.IsHit || _inputController.IsTriggerUp);

startStream
    .SelectMany(x => this.UpdateAsObservable())
    .TakeUntil(stopStream)
    .Select(_ => _raycaster.ResultRaycastHit.point)
    .Pairwise()
    .RepeatUntilDestroy(this)
    .Subscribe(Dragging)
    .AddTo(this);


private void Dragging(Pair<Vector3> points)
{
    Vector3 cur = _anyObj.transform.worldToLocalMatrix.MultiplyPoint3x4(points.Current);
    Vector3 prev = _anyObj.transform.worldToLocalMatrix.MultiplyPoint3x4(points.Previous);

    float velocity = (cur.x - prev.x) / Time.deltaTime;
    _acceleration += velocity / Time.deltaTime;
    _acceleration *= _attenuateOfAcceleration;

    float relativeVelocity = (_acceleration * Time.deltaTime * _coefOfVelocity) - _anyObj.Velocity;

    _anyObj.AddVelocity(relativeVelocity);
}

キモは以下の部分です。

startStream
    .SelectMany(x => this.UpdateAsObservable())
    .TakeUntil(stopStream)
    .Select(_ => _raycaster.ResultRaycastHit.point)
    .Pairwise()
    .RepeatUntilDestroy(this)
    .Subscribe(Dragging)
    .AddTo(this);

ここで行っていることは、『なにがしかのスタートタイミング(startStream)』から開始され、指定の終了ストリーム(stopStream)に値が流れてくるまで継続する、というものです。

そしてストリームが継続している間はレイのヒット位置をストリームに流し(Select)、それをペアにし(Pairwise)、オブジェクトが破棄されるまで継続する、というものです。

値が閾値を越えたら処理をする

例えば、速度が一定速度以上になり、またそれが一定速度以下になった、という閾値またぎを検知したい場合があるかと思います。
そんなときに利用できるのがDistinctUntilChangedフィルタです。

これは、同じ値が連続している間はその値を流さない、という動作をします。
つまり、特定の値(今回の例では速度)が閾値をまたいだときにtrue/falseを返すようにしておき、それをフィルタすることで閾値をまたいだことを検知することができます。

ちなみに以下の例でSkip(1)が入っているのは初期化時など最初に発生してしまうイベントを無視するために入れています。

this.UpdateAsObservable()
    .Select(_ => Mathf.Abs(Velocity) <= _stopLimit)
    .DistinctUntilChanged()
    .Where(x => x)
    .Skip(1)
    .Subscribe(_ => DoAnything())
    .AddTo(this);

一定時間経過したら無効化する

次はボタンなど『開始イベント』と『終了イベント』があり、かつ制限時間を設けたい場合の処理です。

以下のサンプルではまず、開始イベントにスペースキーのDown、終了イベントにスペースキーのUpを設定しています。
そしてさらに、制限時間(例では3秒)が経過した場合も終了するようになっています。

var startStream = this.UpdateAsObservable().Where(_ => Input.GetKeyDown(KeyCode.Space));
var stopStream = this.UpdateAsObservable().Where(_ => Input.GetKeyUp(KeyCode.Space));
var timeOut = Observable.Timer(System.TimeSpan.FromMilliseconds(3000)).Select(_ => false);

startStream
    .SelectMany(stopStream.Select(_ => true).Amb(timeOut ).First())
    .Where(x => x)
    .Subscribe(_ => DoHoge())
    .AddTo(this);

ここでの大事な点はAmbです。
Ambはストリームを合成し、どちらかのストリームに流れたものをそのままひとつのストリームとして流してくれるオペレータです。

なのでここでは『スペースキーUp』と『制限時間経過』を『終了イベント』として捉え、そのどちらかが流れたら終了するようにしています。

ボタン長押しを検知

次はシンプルな『ボタン長押し』の処理です。

1秒後に発火

まず最初は指定時間押し続けていたら発火するもの。

var clickDownStream = this.UpdateAsObservable().Where(_ => Input.GetKeyDown(KeyCode.Space));
var clickUpStream = this.UpdateAsObservable().Where(_ => Input.GetKeyUp(KeyCode.Space));

clickDownStream
    .SelectMany(_ => Observable.Interval(System.TimeSpan.FromSeconds(1)))
    .TakeUntil(clickUpStream)
    .DoOnCompleted(() =>
    {
        Debug.Log("Completed!");
    })
    .RepeatUntilDestroy(this)
    .Subscribe(_ =>
    {
        Debug.Log("pressing...");
    });

押している間、押下を検知

上記は『長押し』だけを検知するものでしたが、こちらは『押している間』のイベントも受け取れるようにしたものです。
使用想定としては、押している間だけ常に何か処理をする、みたいなケースです。

var clickDownStream = this.UpdateAsObservable().Where(_ => Input.GetKeyDown(KeyCode.Space));
var clickUpStream = this.UpdateAsObservable().Where(_ => Input.GetKeyUp(KeyCode.Space));

clickDownStream
    .SelectMany(_ => this.UpdateAsObservable())
    .TakeUntil(clickUpStream)
    .DoOnCompleted(() =>
    {
        Debug.Log("Completed!");
    })
    .RepeatUntilDestroy(this)
    .Subscribe(_ =>
    {
        Debug.Log("pressing...");
    });

一定時間経つ前にボタンが離された場合も処理

こちらはUnity開発者ギルドの質問チャンネルで質問した際に教えていただいた方法です。
上記では『押している間』のイベントを捉えることができましたが、『終了判定』は取れませんでした。
『終了判定』を加えたものが以下のものです。

var startStream = this.UpdateAsObservable().Where(_ => Input.GetKeyDown(KeyCode.Space));
var stopStream = this.UpdateAsObservable().Where(_ => Input.GetKeyUp(KeyCode.Space));
var timeOut = Observable.Timer(System.TimeSpan.FromSeconds(5)).AsUnitObservable();

startStream
    .SelectMany(stopStream.Amb(timeOut))
    .First()
    .RepeatUntilDestroy(this)
    .Subscribe(_ =>
    {
        Debug.Log("hoge");
    });

値の変化監視(true / false)

値の変化監視はよくあるニーズだと思います。
例えば、前回はfalseだったものがtrueになったときだけ処理したい、などです。

using UniRx.Triggers; // ←UpdateAsObservableを使うにはこれが必要

this.UpdateAsObservable()
    .Select(_ => IsHoge())
    .DistinctUntilChanged()
    .Where(x => x)
    .Subscribe(_ =>
    {
        Debug.Log("Is Hoge");
    });

ObserveEveryValueChangedを使ったほうがもっとシンプルに書ける

this.ObserveEveryValueChanged(x => x.IsHoge())
    .Where(x => x)
    .Subscribe(_ =>
    {
        Debug.Log("Is Hoge");
    });

UniRxでコルーチンを使ったアニメーション

こちらはコルーチンを交えてUniRxでアニメーションを行う例です。
Observable.FromCoroutine<T>を使うことでコルーチンをストリームに変え、かつコルーチン内で計算した結果を受け取ることができます。

使い方は、`に渡すラムダの引数にObservableが渡ってくるのでそれをコルーチンに渡し、そのObservableを介してOnNext`を呼んでやることで実現しています。

private void DoAnimation()
{
    Observable.FromCoroutine<float>(o => Animation(o, duration))
        .SubscribeWithState3(transform, transform.position, position,
        (t, trans, start, goal) =>
        {
            trans.position = Vector3.Lerp(start, goal, t);
        })
        .AddTo(this);
}

private IEnumerator Animation(IObserver<float> observer, float duration)
{
    float timer = duration;

    while (timer >= 0)
    {
        timer -= Time.deltaTime;

        float t = 1f - (timer / duration);

        observer.OnNext(t);

        yield return null;
    }
}

参考記事

以下は参考にした記事のリンク集です。

qiita.com

noriok.hatenadiary.jp

developer.aiming-inc.com

qiita.com

qiita.com

www.slideshare.net

rxmarbles.com