Practice 3-4: PyHARK による深層学習を用いた音源定位

深層学習を用いた音源定位

  • User_node_VIVIT(深層学習版音源定位ノード)をリリース
    • pytorchで作成された学習済みモデルをpyharkへ導入
    • 従来のMUSIC法と比較してロバスト性が向上
    • 1~2音源の到来方向を推定可能
  • 使用上の注意
    • User_node_VIVITはGPUの使用を推奨
    • 精度はCPU上でも変わらないが,実行時間が遅い
    • 本講義でのデモは全てGPU上で実行

動かしてみる(デモ)

  • 入力音声
    • 雑音が混じった男性と女性の音声 (rec_060_300_0_7.wav)
    • 男性と女性は60,300度方向から30秒間発声
  • プログラムを実行する
  • $ python3 practice3-4.py
  • 出力結果
    • 60,300度方向から音を検出できていることを確認
    PyHARK 深層学習による音源定位のデモ

VIVITとMUSICとの比較

PyHARK 深層学習とMUSICとの比較
User_node_VIVITは背景雑音にロバストで精度良く推定可能.

深層学習を利用した音源定位のプログラム

Practice3-1 での MUSIC を User_node_VIVIT に変更する

MUSICをVIVITに変更

User Node VIVIT を使用した Python プログラム

1.practice3-1.py の作成するノードの変更

  • practice3-1 でMUSIC関係のノードを
  • class HARK_Localization(hark.base.NetworkDef):
        def build(. . .):
    
            try:
                # 必要なノードを作成する
                node_cm_identity_matrix = network.create(
                    hark.core.CMIdentityMatrix,
                    dispatch=hark.base.RepeatDispatcher
                )
                node_constant__for_operation_flag = network.create(
                    hark.node.Constant,
                    dispatch=hark.base.RepeatDispatcher
                )
                node_localize_music = network.create(hark.core.LocalizeMUSIC)
                node_source_tracker = network.create(hark.core.SourceTracker)
                node_source_interval_extender = network.create(hark.core.SourceIntervalExtender)
                node_plotsource_kivy = network.create(plotQuickSourceKivy)
    		
  • practice3-4 で User_node_VIVIT に変更
  • class HARK_Localization(hark.base.NetworkDef):
        def build(. . .):
            try:
                # 必要なノードを作成する
                node_VIVIT = network.create(User_node_VIVIT.User_node_VIVIT)
                node_source_tracker = network.create(hark.core.SourceTracker)
                node_source_interval_extender = network.create(hark.core.SourceIntervalExtender)
                node_plotsource_kivy = network.create(plotQuickSourceKivy)
    		

2. practice3-1.py のノード接続の変更

  • practice3-1 でMUSIC関係のノードを
  • class HARK_Localization(hark.base.NetworkDef):
        def build(. . .):
    
            try:
     
                # ノード間の接続(データの流れ)とパラメータを記述する
                 node_cm_identity_matrix
    	                 .add_input("NB_CHANNELS", 8)
                         .add_input("LENGTH", 512)
                ,
                node_constant__for_operation_flag.add_input("VALUE", True)
    
                node_localize_music
                         .add_input("INPUT", input["SPEC"])
                         .add_input("NOISECM", node_cm_identity_matrix["OUTPUT"])
                         .add_input("OPERATION_FLAG", node_constant__for_operation_flag["OUTPUT"])
                         .add_input("MUSIC_ALGORITHM", "SEVD")
                         .add_input("A_MATRIX", "tf.zip")
                         .add_input("WINDOW_TYPE", "MIDDLE")
                         .add_input("LOWER_BOUND_FREQUENCY", 3000)
                         .add_input("UPPER_BOUND_FREQUENCY", 6000)
                         .add_input("SPECTRUM_WEIGHT_TYPE", "A_Characteristic")
                         .add_input("ENABLE_EIGENVALUE_WEIGHT", False)
                         .add_input("ENABLE_OUTPUT_SPECTRUM", True)
                ,
                node_source_tracker
                         .add_input("INPUT", node_localize_music["OUTPUT"])
                         .add_input("THRESH", 25.0)
                         .add_input("PAUSE_LENGTH", 1200.0)
                ,
                node_source_interval_extender
                         .add_input("SOURCES", node_source_tracker["OUTPUT"])
                         .add_input("PREROLL_LENGTH", 80)
                ,
    		
  • practice3-4 で User_node_VIVIT に変更
  • class HARK_Localization(hark.base.NetworkDef):
        def build(. . .):
     
           try:
     
                # ノード間の接続(データの流れ)とパラメータを記述する
                node_VIVIT
                         .add_input("INPUT", input["SPEC"])
                         .add_input("BATCH_SIZE, 64)
    		    ,
                node_source_tracker
                         .add_input("INPUT", node_VIVIT["OUTPUT"])
                         .add_input("THRESH", 0.90)   #目安は0.8~0.9  
                         .add_input("PAUSE_LENGTH", 1200.0)
                ,
                node_source_interval_extender
                         .add_input("SOURCES", node_source_tracker["OUTPUT"])
                         .add_input("PREROLL_LENGTH", 80)
                ,
    		

まとめ

  • User_node_VIVIT(深層学習版音源定位ノード)をリリース
    • MUSIC法と比較してノイズ下での精度が向上
    • 1~2音源の到来方向を推定可能
  • 使い方はpractice3-1とほとんど同じ
    • CMIdentityMatrixとLocalizeMUSICをUser_node_VIVITに変更

付録: プログラムの全コード

practice3-1.pyの全コード

#!/usr/bin/env python

import sys
import threading
import time

import numpy as np
import soundfile as sf

import hark
import hark.base
import hark.node
import hark.core
from hark.modules.plot.kivynodes.plotQuickSourceKivy import plotQuickSourceKivy

class HARK_Localization(hark.base.NetworkDef):
    def build(
            self,
            network: hark.base.Network,
            input: hark.base.DataSourceMap,
            output: hark.base.DataSinkMap
    ):
        try:
            # 必要なノードを作成する
            node_cm_identity_matrix = network.create(
                hark.core.CMIdentityMatrix,
                dispatch=hark.base.RepeatDispatcher
            )
            node_constant__for_operation_flag = network.create(
                hark.node.Constant,
                dispatch=hark.base.RepeatDispatcher
            )
            node_localize_music = network.create(hark.core.LocalizeMUSIC)
            node_source_tracker = network.create(hark.core.SourceTracker)
            node_source_interval_extender = network.create(hark.core.SourceIntervalExtender)
            node_plotsource_kivy = network.create(plotQuickSourceKivy)

            # ノード間の接続(データの流れ)とパラメータを記述する
            node_cm_identity_matrix.add_input("NB_CHANNELS", 8)
            node_cm_identity_matrix.add_input("LENGTH", 512)

            node_constant__for_operation_flag.add_input("VALUE", True)

            node_localize_music.add_input("INPUT", input["SPEC"])
            node_localize_music.add_input("NOISECM", node_cm_identity_matrix["OUTPUT"])
            node_localize_music.add_input("OPERATION_FLAG", node_constant__for_operation_flag["OUTPUT"])
            node_localize_music.add_input("MUSIC_ALGORITHM", "SEVD")
            node_localize_music.add_input("A_MATRIX", "tf.zip")
            node_localize_music.add_input("WINDOW_TYPE", "MIDDLE")
            node_localize_music.add_input("LOWER_BOUND_FREQUENCY", 3000)
            node_localize_music.add_input("UPPER_BOUND_FREQUENCY", 6000)
            node_localize_music.add_input("SPECTRUM_WEIGHT_TYPE", "A_Characteristic")
            node_localize_music.add_input("ENABLE_EIGENVALUE_WEIGHT", False)
            node_localize_music.add_input("ENABLE_OUTPUT_SPECTRUM", True)

            node_source_tracker.add_input("INPUT", node_localize_music["OUTPUT"])
            node_source_tracker.add_input("THRESH", 25.0)
            node_source_tracker.add_input("PAUSE_LENGTH", 1200.0)

            node_source_interval_extender.add_input("SOURCES", node_source_tracker["OUTPUT"])
            node_source_interval_extender.add_input("PREROLL_LENGTH", 80)

            node_plotsource_kivy.add_input("SOURCES", node_source_interval_extender["OUTPUT"])

            output.add_input("OUTPUT", node_source_interval_extender["OUTPUT"])
            
            # ネットワークに含まれるノードの一覧を含むリストを作成する
            r = [
                node_cm_identity_matrix,
                node_constant__for_operation_flag,
                node_localize_music,
                node_source_tracker,
                node_source_interval_extender,
                node_plotsource_kivy,
            ]
        except BaseException as ex:
            print(f"error: {repr(ex)}")

        # ノード一覧のリストを返す
        return r

class HARK_Loop(hark.base.NetworkDef):

    def build(
            self,
            network: hark.base.Network,
            input: hark.base.DataSourceMap,
            output: hark.base.DataSinkMap
    ):
        try:
            # 必要なノードを作成する
            node_audio_stream_from_memory = network.create(
                hark.core.AudioStreamFromMemory,
                dispatch=hark.base.TriggeredMultiShotDispatcher,
                name="AudioStreamFromMemory"
            )
            node_multi_fft = network.create(hark.core.MultiFFT)
            node_sub_localization = network.create(
                HARK_Localization,
                name="Localization"
            )

            # ノード間の接続(データの流れ)とパラメータを記述し,
            # ネットワークに含まれるノードの一覧を含むリストを作成する
            r = [
                node_audio_stream_from_memory
                    .add_input("INPUT", input["INPUT"])
                    .add_input("CHANNEL_COUNT", 8)
                ,
                node_multi_fft
                    .add_input("INPUT", node_audio_stream_from_memory["AUDIO"])
                ,
                node_sub_localization
                    .add_input("SPEC", node_multi_fft["OUTPUT"])
                ,
            ]
            output.add_input("OUTPUT", node_sub_localization["OUTPUT"])
        except BaseException as ex:
            print(f"error: {repr(ex)}")

        # ノード一覧のリストを返す
        return r

class HARK_Main(hark.base.NetworkDef):
    """
    メインネットワークに相当するクラス。
    入力として8ch音響信号を受け取り、
    フーリエ変換、MUSIC法による音源定位、音源追跡を行い、
    その結果を図示する。
    """

    def build(
            self,
            network: hark.base.Network,
            input: hark.base.DataSourceMap,
            output: hark.base.DataSinkMap
    ):
        try:
            # 必要なノードを作成する.
            # - 全体の入出力を扱う Publisher と Subscriber
            # - HARK_Loop サブネット
            node_publisher = network.create(
                hark.node.PublishData,
                dispatch=hark.base.RepeatDispatcher,
                name="Publisher"
            )
            node_subscriber = network.create(
                hark.node.SubscribeData,
                name="Subscriber"
            )
            loop = network.create(
                HARK_Loop,
                name="HARK_Loop"
            )

            # ノード間の接続(データの流れ)とパラメータを記述し,
            # ネットワークに含まれるノードの一覧を含むリストを作成する
            r = [
                loop
                    .add_input("INPUT", node_publisher["OUTPUT"])
                ,
                node_subscriber
                    .add_input("INPUT", loop["OUTPUT"])
                ,
            ]
        except BaseException as ex:
            print(f"error: {repr(ex)}")

        # ノード一覧のリストを返す
        return r

def main():
    # コマンドライン引数の処理
    if len(sys.argv) < 2:
        print("no input file")
        return
    wavfilename = sys.argv[1]

    # メインネットワークを構築
    network = hark.base.Network.from_networkdef(HARK_Main, name="HARK_Main")

    # メインネットワークへの入出力を構築
    publisher = network.query_nodedef("Publisher")
    subscriber = network.query_nodedef("Subscriber")

    # subscriber がデータを受け取ったとき
    # (メインネットワークが結果を出力したとき)に
    # 実行される動作を定義する。
    # ここでは pass を用いることで「何もしない」ことを指示する。
    def received(data):
        pass

    subscriber.receive = received

    try:
        # ネットワーク実行用スレッドを開始
        th = threading.Thread(target=network.execute)
        th.start()

        # 入力ファイル読み込み
        audio, samplerate = sf.read(wavfilename, dtype=np.int16)

        # フレーム分割
        advance = 160
        # 2023講習会VMは numpy==1.21.5
        # for numpy>=1.20.0
        frames = np.lib.stride_tricks.sliding_window_view(audio, advance, axis=0)[::advance, :, :]

        # フレームごとに処理
        for frame in frames:
            # もしネットワーク実行用スレッドが停止していたら
            # ループを抜け処理全体を停止させる
            if not th.is_alive():
                break

            # ネットワークに1フレーム分の音響信号を送信
            publisher.push(frame)

            # リアルタイム処理と同等程度の処理時間となるように
            # 音響信号送信間隔を調整する
            # time.sleep(advance / samplerate)

    except BaseException as ex:
        print(f"error: {repr(ex)}")
    except:
        network.stop()
    finally:
        # 終了処理
        publisher.close()
        if th.ident is not None:
            th.join()

if __name__ == "__main__":
    main()

# end of file

 

practice3-4.pyの全コード

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import threading
import time

import matplotlib.pyplot as plt

import numpy 
import hark
import hark.base
import hark.node
import hark.core

import hark.modules.localization.dlsslnodes as dlssl
import struct
import wave
import soundfile
import hark.modules.plot.kivynodes as p
import argparse
import tempfile

def check_version_requirements(package, version):
    for t in [x for x in zip(package.__version__.split('.'), version.split('.'))]:
        if t[0] > t[1]:
            return True
        elif t[0] < t[1]:
            return False
    return True


class HARK_Localization(hark.base.NetworkDef):
    def build(self,
              network: hark.base.Network,
              input:   hark.base.DataSourceMap,
              output:  hark.base.DataSinkMap):

        try:
            node_VIVIT = network.create(dlssl.User_node_VIVIT.User_node_VIVIT)
            node_source_tracker = network.create(hark.core.SourceTracker)
            node_plotsource_kivy = network.create(p.plotQuickSourceKivy.plotQuickSourceKivy)
        except BaseException as ex:
            print(ex)

        try:
            r = [
                node_VIVIT
                     .add_input("INPUT", input["SPEC"])
                     .add_input("BATCH_SIZE", 4)
                ,
                node_source_tracker
                    .add_input("INPUT", node_VIVIT["OUTPUT"])
                    #.add_input("THRESH", 20.0)
                    .add_input("THRESH", 0.90)
                    #.add_input("PAUSE_LENGTH", 800.0)
                    .add_input("PAUSE_LENGTH", 20)
                    .add_input("MIN_SRC_INTERVAL", 15.0)
                    #.add_input("MIN_ID", 0)
                    .add_input("DEBUG", True)
                ,
                node_plotsource_kivy
                    .add_input("SOURCES", node_source_tracker["OUTPUT"])
                ,
            ]
            
            output.add_input("OUTPUT", node_source_tracker["OUTPUT"])

        except BaseException as ex:
            print('error: {}'.format(ex))

        return r


class HARK_MAIN_LOOP(hark.base.NetworkDef):
    def build(self,
              network: hark.base.Network,
              input:   hark.base.DataSourceMap,
              output:  hark.base.DataSinkMap):

        try:
            node_audio_stream_from_memory = network.create(hark.core.AudioStreamFromMemory, dispatch=hark.base.TriggeredMultiShotDispatcher, name="AudioStreamFromMemory")
            #node_vadzc =network.create(hark.core.VADZC)
            node_multi_fft = network.create(hark.core.MultiFFT)
            node_sub_localization = network.create(HARK_Localization, name="Localization")
        except BaseException as ex:
            print(ex)

        try:
            r = [
                node_audio_stream_from_memory
                    .add_input("INPUT", input["INPUT"])
                    #.add_input("SAMPLING_RATE", 16000)
                    #.add_input("SAMPLE_BITS", 16)
                    #.add_input("CHANNEL_COUNT", 1)
                    .add_input("CHANNEL_COUNT", 8)
                    #.add_input("LENGTH", 512)
                    #.add_input("ADVANCE", 160)
                    #.add_input("USE_WAIT", False)
                    #.add_input("MAXIMUM_BUFFER_LENGTH", 10.0)
                ,
                #node_vadzc
                #    .add_input("INPUT", node_audio_stream_from_memory["AUDIO"])
                #    .add_input("CHANNEL_COUNT", 8)
                #    .add_input("LEVEL_THRESHOLD",0)
                #,
                node_multi_fft
                    #.add_input("INPUT", node_vadzc["OUTPUT"])
                    .add_input("INPUT", node_audio_stream_from_memory["AUDIO"])
                    # .add_input("LENGTH", 512)
                    .add_input("WINDOW", "HANNING")
                    # .add_input("WINDOW_LENGTH", 512)
                ,
                node_sub_localization
                    .add_input("SPEC", node_multi_fft["OUTPUT"])
                ,
            ]
            
            output.add_input("OUTPUT", node_sub_localization["OUTPUT"])

        except BaseException as ex:
            print('error: {}'.format(ex))

        return r


class HARK_Main(hark.base.NetworkDef):
    def __init__(self):
        hark.base.NetworkDef.__init__(self)

    def build(self,
              network: hark.base.Network,
              input:   hark.base.DataSourceMap,
              output:  hark.base.DataSinkMap):

        try:
            node_publisher = network.create(hark.node.PublishData, dispatch=hark.base.RepeatDispatcher, name="Publisher")
            node_subscriber = network.create(hark.node.SubscribeData, name="Subscriber")
            loop = network.create(HARK_MAIN_LOOP, name="HARK_Main_Loop")

        except BaseException as ex:
            print(ex)

        try:
            r = [
                loop
                    .add_input("INPUT", node_publisher["OUTPUT"]),
                node_subscriber.add_input("INPUT", loop["OUTPUT"])
                ,
            ]

        except BaseException as ex:
            print(ex)

        return r

def received(data):
    pass
    #print('>>>> received: {}'.format(data))

def main1(args=sys.argv[1:]):

    if len(args) > 0:
        if args[0] == '--online':
            use_online = True
        elif args[0] == '--offline':
            use_online = False
        else:
            raise BaseException("Unexpected argument {}, use --online or --offline".format(args[0]))
    else:
        use_online = True # default mode

    if use_online:
        # online
        network = hark.base.Network.from_networkdef(HARK_Main, name="HARK_Main1")

        publisher = network.query_nodedef("Publisher")
        subscriber = network.query_nodedef("Subscriber")
        subscriber.receive = received

        try:
            
            th = threading.Thread(target=network.execute)
            th.start()

            # Generate audio frames
            advance = 160
            audio, rate = soundfile.read('rec_060_300_0_7.wav', dtype=numpy.int16)
            print(f"rate={rate}")
            frames = numpy.array([[[0] * audio.shape[1]] * advance])
            if not check_version_requirements(numpy, "1.20.0"):
                frames = numpy.lib.stride_tricks.as_strided(audio, shape=(int(audio.shape[0]/advance), advance, audio.shape[1]), strides=(advance * audio.shape[1] * audio.strides[1], audio.shape[1]*audio.strides[1], audio.strides[1]))
            else:
                frames = numpy.lib.stride_tricks.sliding_window_view(audio, advance, axis=0)[::advance, :, :]
            print("DEBUG:\n  frames.shape={}\n".format(frames.shape))
            for t in range(len(frames)):
                if not th.is_alive():
                    break
                #print('<<<< send: count={}, max={}'.format(t, numpy.max(frames[t])))
                #print(frames[t].shape)
                #publisher.push(numpy.reshape(frames[t], frames[t].shape[0] * frames[t].shape[1]).astype(numpy.int16))
                publisher.push(frames[t].astype(numpy.int16))
                #publisher.push(frames[t])
                #time.sleep(0.01)

        except BaseException as ex:
            print(ex)
        except:
            network.stop()
        finally:
            publisher.close()
            if th.ident is not None:
                th.join()

    else:
        raise NotImplementedError("Sorry, offline mode for network-def is not implemented yet.")
        # offline
        data = numpy.zeros(160000, dtype=numpy.int16)
        data[(numpy.arange(len(data)) // 200) % 2 == 1] = (4096 * numpy.random.randn(80000)).astype(numpy.int16)

        network = hark.base.Network.from_networkdef(HARK_Recording)
        output = network(INPUT=data)
        #print(output.OUTPUT)


if __name__ == '__main__':
    start = time.time()
    main1(sys.argv[1:])
    end = time.time()
    print(f"Total_time={end-start}")
    #main()

# end of file
以上で講習会は終了です。お疲れ様。