Practice 3-4: PyHARK による深層学習を用いた音源定位
深層学習を用いた音源定位
- User_node_VIVIT(深層学習版音源定位ノード)をリリース
- pytorchで作成された学習済みモデルをpyharkへ導入
- 従来のMUSIC法と比較してロバスト性が向上
- 1~2音源の到来方向を推定可能
- 使用上の注意
- User_node_VIVITはGPUの使用を推奨
- 精度はCPU上でも変わらないが,実行時間が遅い
- 本講義でのデモは全てGPU上で実行
動かしてみる(デモ)
- 入力音声
- 雑音が混じった男性と女性の音声 (
rec_060_300_0_7.wav
) - 男性と女性は60,300度方向から30秒間発声
- プログラムを実行する
$ python3 practice3-4.py
- 60,300度方向から音を検出できていることを確認
VIVITとMUSICとの比較
User_node_VIVITは背景雑音にロバストで精度良く推定可能.
深層学習を利用した音源定位のプログラム
Practice3-1 での MUSIC を User_node_VIVIT に変更する
User Node VIVIT を使用した Python プログラム
1.practice3-1.py の作成するノードの変更
- practice3-1 でMUSIC関係のノードを
class HARK_Localization(hark.base.NetworkDef): def build(. . .): try: # 必要なノードを作成する node_cm_identity_matrix = network.create( hark.core.CMIdentityMatrix, dispatch=hark.base.RepeatDispatcher ) node_constant__for_operation_flag = network.create( hark.node.Constant, dispatch=hark.base.RepeatDispatcher ) node_localize_music = network.create(hark.core.LocalizeMUSIC) node_source_tracker = network.create(hark.core.SourceTracker) node_source_interval_extender = network.create(hark.core.SourceIntervalExtender) node_plotsource_kivy = network.create(plotQuickSourceKivy)
class HARK_Localization(hark.base.NetworkDef): def build(. . .): try: # 必要なノードを作成する node_VIVIT = network.create(User_node_VIVIT.User_node_VIVIT) node_source_tracker = network.create(hark.core.SourceTracker) node_source_interval_extender = network.create(hark.core.SourceIntervalExtender) node_plotsource_kivy = network.create(plotQuickSourceKivy)
2. practice3-1.py のノード接続の変更
- practice3-1 でMUSIC関係のノードを
class HARK_Localization(hark.base.NetworkDef): def build(. . .): try: # ノード間の接続(データの流れ)とパラメータを記述する node_cm_identity_matrix .add_input("NB_CHANNELS", 8) .add_input("LENGTH", 512) , node_constant__for_operation_flag.add_input("VALUE", True) node_localize_music .add_input("INPUT", input["SPEC"]) .add_input("NOISECM", node_cm_identity_matrix["OUTPUT"]) .add_input("OPERATION_FLAG", node_constant__for_operation_flag["OUTPUT"]) .add_input("MUSIC_ALGORITHM", "SEVD") .add_input("A_MATRIX", "tf.zip") .add_input("WINDOW_TYPE", "MIDDLE") .add_input("LOWER_BOUND_FREQUENCY", 3000) .add_input("UPPER_BOUND_FREQUENCY", 6000) .add_input("SPECTRUM_WEIGHT_TYPE", "A_Characteristic") .add_input("ENABLE_EIGENVALUE_WEIGHT", False) .add_input("ENABLE_OUTPUT_SPECTRUM", True) , node_source_tracker .add_input("INPUT", node_localize_music["OUTPUT"]) .add_input("THRESH", 25.0) .add_input("PAUSE_LENGTH", 1200.0) , node_source_interval_extender .add_input("SOURCES", node_source_tracker["OUTPUT"]) .add_input("PREROLL_LENGTH", 80) ,
class HARK_Localization(hark.base.NetworkDef): def build(. . .): try: # ノード間の接続(データの流れ)とパラメータを記述する node_VIVIT .add_input("INPUT", input["SPEC"]) .add_input("BATCH_SIZE, 64) , node_source_tracker .add_input("INPUT", node_VIVIT["OUTPUT"]) .add_input("THRESH", 0.90) #目安は0.8~0.9 .add_input("PAUSE_LENGTH", 1200.0) , node_source_interval_extender .add_input("SOURCES", node_source_tracker["OUTPUT"]) .add_input("PREROLL_LENGTH", 80) ,
まとめ
- User_node_VIVIT(深層学習版音源定位ノード)をリリース
- MUSIC法と比較してノイズ下での精度が向上
- 1~2音源の到来方向を推定可能
- 使い方はpractice3-1とほとんど同じ
- CMIdentityMatrixとLocalizeMUSICをUser_node_VIVITに変更
付録: プログラムの全コード
practice3-1.py
の全コード
#!/usr/bin/env python import sys import threading import time import numpy as np import soundfile as sf import hark import hark.base import hark.node import hark.core from hark.modules.plot.kivynodes.plotQuickSourceKivy import plotQuickSourceKivy class HARK_Localization(hark.base.NetworkDef): def build( self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap ): try: # 必要なノードを作成する node_cm_identity_matrix = network.create( hark.core.CMIdentityMatrix, dispatch=hark.base.RepeatDispatcher ) node_constant__for_operation_flag = network.create( hark.node.Constant, dispatch=hark.base.RepeatDispatcher ) node_localize_music = network.create(hark.core.LocalizeMUSIC) node_source_tracker = network.create(hark.core.SourceTracker) node_source_interval_extender = network.create(hark.core.SourceIntervalExtender) node_plotsource_kivy = network.create(plotQuickSourceKivy) # ノード間の接続(データの流れ)とパラメータを記述する node_cm_identity_matrix.add_input("NB_CHANNELS", 8) node_cm_identity_matrix.add_input("LENGTH", 512) node_constant__for_operation_flag.add_input("VALUE", True) node_localize_music.add_input("INPUT", input["SPEC"]) node_localize_music.add_input("NOISECM", node_cm_identity_matrix["OUTPUT"]) node_localize_music.add_input("OPERATION_FLAG", node_constant__for_operation_flag["OUTPUT"]) node_localize_music.add_input("MUSIC_ALGORITHM", "SEVD") node_localize_music.add_input("A_MATRIX", "tf.zip") node_localize_music.add_input("WINDOW_TYPE", "MIDDLE") node_localize_music.add_input("LOWER_BOUND_FREQUENCY", 3000) node_localize_music.add_input("UPPER_BOUND_FREQUENCY", 6000) node_localize_music.add_input("SPECTRUM_WEIGHT_TYPE", "A_Characteristic") node_localize_music.add_input("ENABLE_EIGENVALUE_WEIGHT", False) node_localize_music.add_input("ENABLE_OUTPUT_SPECTRUM", True) node_source_tracker.add_input("INPUT", node_localize_music["OUTPUT"]) node_source_tracker.add_input("THRESH", 25.0) node_source_tracker.add_input("PAUSE_LENGTH", 1200.0) node_source_interval_extender.add_input("SOURCES", node_source_tracker["OUTPUT"]) node_source_interval_extender.add_input("PREROLL_LENGTH", 80) node_plotsource_kivy.add_input("SOURCES", node_source_interval_extender["OUTPUT"]) output.add_input("OUTPUT", node_source_interval_extender["OUTPUT"]) # ネットワークに含まれるノードの一覧を含むリストを作成する r = [ node_cm_identity_matrix, node_constant__for_operation_flag, node_localize_music, node_source_tracker, node_source_interval_extender, node_plotsource_kivy, ] except BaseException as ex: print(f"error: {repr(ex)}") # ノード一覧のリストを返す return r class HARK_Loop(hark.base.NetworkDef): def build( self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap ): try: # 必要なノードを作成する node_audio_stream_from_memory = network.create( hark.core.AudioStreamFromMemory, dispatch=hark.base.TriggeredMultiShotDispatcher, name="AudioStreamFromMemory" ) node_multi_fft = network.create(hark.core.MultiFFT) node_sub_localization = network.create( HARK_Localization, name="Localization" ) # ノード間の接続(データの流れ)とパラメータを記述し, # ネットワークに含まれるノードの一覧を含むリストを作成する r = [ node_audio_stream_from_memory .add_input("INPUT", input["INPUT"]) .add_input("CHANNEL_COUNT", 8) , node_multi_fft .add_input("INPUT", node_audio_stream_from_memory["AUDIO"]) , node_sub_localization .add_input("SPEC", node_multi_fft["OUTPUT"]) , ] output.add_input("OUTPUT", node_sub_localization["OUTPUT"]) except BaseException as ex: print(f"error: {repr(ex)}") # ノード一覧のリストを返す return r class HARK_Main(hark.base.NetworkDef): """ メインネットワークに相当するクラス。 入力として8ch音響信号を受け取り、 フーリエ変換、MUSIC法による音源定位、音源追跡を行い、 その結果を図示する。 """ def build( self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap ): try: # 必要なノードを作成する. # - 全体の入出力を扱う Publisher と Subscriber # - HARK_Loop サブネット node_publisher = network.create( hark.node.PublishData, dispatch=hark.base.RepeatDispatcher, name="Publisher" ) node_subscriber = network.create( hark.node.SubscribeData, name="Subscriber" ) loop = network.create( HARK_Loop, name="HARK_Loop" ) # ノード間の接続(データの流れ)とパラメータを記述し, # ネットワークに含まれるノードの一覧を含むリストを作成する r = [ loop .add_input("INPUT", node_publisher["OUTPUT"]) , node_subscriber .add_input("INPUT", loop["OUTPUT"]) , ] except BaseException as ex: print(f"error: {repr(ex)}") # ノード一覧のリストを返す return r def main(): # コマンドライン引数の処理 if len(sys.argv) < 2: print("no input file") return wavfilename = sys.argv[1] # メインネットワークを構築 network = hark.base.Network.from_networkdef(HARK_Main, name="HARK_Main") # メインネットワークへの入出力を構築 publisher = network.query_nodedef("Publisher") subscriber = network.query_nodedef("Subscriber") # subscriber がデータを受け取ったとき # (メインネットワークが結果を出力したとき)に # 実行される動作を定義する。 # ここでは pass を用いることで「何もしない」ことを指示する。 def received(data): pass subscriber.receive = received try: # ネットワーク実行用スレッドを開始 th = threading.Thread(target=network.execute) th.start() # 入力ファイル読み込み audio, samplerate = sf.read(wavfilename, dtype=np.int16) # フレーム分割 advance = 160 # 2023講習会VMは numpy==1.21.5 # for numpy>=1.20.0 frames = np.lib.stride_tricks.sliding_window_view(audio, advance, axis=0)[::advance, :, :] # フレームごとに処理 for frame in frames: # もしネットワーク実行用スレッドが停止していたら # ループを抜け処理全体を停止させる if not th.is_alive(): break # ネットワークに1フレーム分の音響信号を送信 publisher.push(frame) # リアルタイム処理と同等程度の処理時間となるように # 音響信号送信間隔を調整する # time.sleep(advance / samplerate) except BaseException as ex: print(f"error: {repr(ex)}") except: network.stop() finally: # 終了処理 publisher.close() if th.ident is not None: th.join() if __name__ == "__main__": main() # end of file
practice3-4.py
の全コード
#! /usr/bin/env python # -*- coding: utf-8 -*- import sys import threading import time import matplotlib.pyplot as plt import numpy import hark import hark.base import hark.node import hark.core import hark.modules.localization.dlsslnodes as dlssl import struct import wave import soundfile import hark.modules.plot.kivynodes as p import argparse import tempfile def check_version_requirements(package, version): for t in [x for x in zip(package.__version__.split('.'), version.split('.'))]: if t[0] > t[1]: return True elif t[0] < t[1]: return False return True class HARK_Localization(hark.base.NetworkDef): def build(self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap): try: node_VIVIT = network.create(dlssl.User_node_VIVIT.User_node_VIVIT) node_source_tracker = network.create(hark.core.SourceTracker) node_plotsource_kivy = network.create(p.plotQuickSourceKivy.plotQuickSourceKivy) except BaseException as ex: print(ex) try: r = [ node_VIVIT .add_input("INPUT", input["SPEC"]) .add_input("BATCH_SIZE", 4) , node_source_tracker .add_input("INPUT", node_VIVIT["OUTPUT"]) #.add_input("THRESH", 20.0) .add_input("THRESH", 0.90) #.add_input("PAUSE_LENGTH", 800.0) .add_input("PAUSE_LENGTH", 20) .add_input("MIN_SRC_INTERVAL", 15.0) #.add_input("MIN_ID", 0) .add_input("DEBUG", True) , node_plotsource_kivy .add_input("SOURCES", node_source_tracker["OUTPUT"]) , ] output.add_input("OUTPUT", node_source_tracker["OUTPUT"]) except BaseException as ex: print('error: {}'.format(ex)) return r class HARK_MAIN_LOOP(hark.base.NetworkDef): def build(self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap): try: node_audio_stream_from_memory = network.create(hark.core.AudioStreamFromMemory, dispatch=hark.base.TriggeredMultiShotDispatcher, name="AudioStreamFromMemory") #node_vadzc =network.create(hark.core.VADZC) node_multi_fft = network.create(hark.core.MultiFFT) node_sub_localization = network.create(HARK_Localization, name="Localization") except BaseException as ex: print(ex) try: r = [ node_audio_stream_from_memory .add_input("INPUT", input["INPUT"]) #.add_input("SAMPLING_RATE", 16000) #.add_input("SAMPLE_BITS", 16) #.add_input("CHANNEL_COUNT", 1) .add_input("CHANNEL_COUNT", 8) #.add_input("LENGTH", 512) #.add_input("ADVANCE", 160) #.add_input("USE_WAIT", False) #.add_input("MAXIMUM_BUFFER_LENGTH", 10.0) , #node_vadzc # .add_input("INPUT", node_audio_stream_from_memory["AUDIO"]) # .add_input("CHANNEL_COUNT", 8) # .add_input("LEVEL_THRESHOLD",0) #, node_multi_fft #.add_input("INPUT", node_vadzc["OUTPUT"]) .add_input("INPUT", node_audio_stream_from_memory["AUDIO"]) # .add_input("LENGTH", 512) .add_input("WINDOW", "HANNING") # .add_input("WINDOW_LENGTH", 512) , node_sub_localization .add_input("SPEC", node_multi_fft["OUTPUT"]) , ] output.add_input("OUTPUT", node_sub_localization["OUTPUT"]) except BaseException as ex: print('error: {}'.format(ex)) return r class HARK_Main(hark.base.NetworkDef): def __init__(self): hark.base.NetworkDef.__init__(self) def build(self, network: hark.base.Network, input: hark.base.DataSourceMap, output: hark.base.DataSinkMap): try: node_publisher = network.create(hark.node.PublishData, dispatch=hark.base.RepeatDispatcher, name="Publisher") node_subscriber = network.create(hark.node.SubscribeData, name="Subscriber") loop = network.create(HARK_MAIN_LOOP, name="HARK_Main_Loop") except BaseException as ex: print(ex) try: r = [ loop .add_input("INPUT", node_publisher["OUTPUT"]), node_subscriber.add_input("INPUT", loop["OUTPUT"]) , ] except BaseException as ex: print(ex) return r def received(data): pass #print('>>>> received: {}'.format(data)) def main1(args=sys.argv[1:]): if len(args) > 0: if args[0] == '--online': use_online = True elif args[0] == '--offline': use_online = False else: raise BaseException("Unexpected argument {}, use --online or --offline".format(args[0])) else: use_online = True # default mode if use_online: # online network = hark.base.Network.from_networkdef(HARK_Main, name="HARK_Main1") publisher = network.query_nodedef("Publisher") subscriber = network.query_nodedef("Subscriber") subscriber.receive = received try: th = threading.Thread(target=network.execute) th.start() # Generate audio frames advance = 160 audio, rate = soundfile.read('rec_060_300_0_7.wav', dtype=numpy.int16) print(f"rate={rate}") frames = numpy.array([[[0] * audio.shape[1]] * advance]) if not check_version_requirements(numpy, "1.20.0"): frames = numpy.lib.stride_tricks.as_strided(audio, shape=(int(audio.shape[0]/advance), advance, audio.shape[1]), strides=(advance * audio.shape[1] * audio.strides[1], audio.shape[1]*audio.strides[1], audio.strides[1])) else: frames = numpy.lib.stride_tricks.sliding_window_view(audio, advance, axis=0)[::advance, :, :] print("DEBUG:\n frames.shape={}\n".format(frames.shape)) for t in range(len(frames)): if not th.is_alive(): break #print('<<<< send: count={}, max={}'.format(t, numpy.max(frames[t]))) #print(frames[t].shape) #publisher.push(numpy.reshape(frames[t], frames[t].shape[0] * frames[t].shape[1]).astype(numpy.int16)) publisher.push(frames[t].astype(numpy.int16)) #publisher.push(frames[t]) #time.sleep(0.01) except BaseException as ex: print(ex) except: network.stop() finally: publisher.close() if th.ident is not None: th.join() else: raise NotImplementedError("Sorry, offline mode for network-def is not implemented yet.") # offline data = numpy.zeros(160000, dtype=numpy.int16) data[(numpy.arange(len(data)) // 200) % 2 == 1] = (4096 * numpy.random.randn(80000)).astype(numpy.int16) network = hark.base.Network.from_networkdef(HARK_Recording) output = network(INPUT=data) #print(output.OUTPUT) if __name__ == '__main__': start = time.time() main1(sys.argv[1:]) end = time.time() print(f"Total_time={end-start}") #main() # end of file
以上で講習会は終了です。お疲れ様。