본문 바로가기
Mobile : Android/Library

[Android Library] Tensorflow Lite - Sound Classification

by 신숭이 2021. 8. 25.

[Android Library] Tensorflow Lite - Sound Classification

 

 

 

경량화 기술의 발달로 모바일에서 AI 모델을 돌릴 수 있는 기술들이 많이 등장하고 있다.

지난번에 PyTorch Mobile 을 이용한 음성 인식 모델 사용법에 이어, 좀 더 간편하게 사용할 수 있는 Tensorflow Lite 라이브러리를 소개하고자 한다.예시는 YAMNet 모델 이용하여 521가지의 소리를 분류해 보겠다.

 

언어는 Java를 이용하겠다. Kotlin 은 구글에서 제공하는 공식 레퍼런스를 참조하면 된다.

(프로젝트가 Java 여서 테스트 코드를 만들다 보니, 부득이하게 Java를 예시로 들겠다)

 

UI 는 없이 Log로 찍어서 확인하겠다.

 

 

 

준비

 

1. Gradle 앱수준의 dependencies

implementation 'org.tensorflow:tensorflow-lite-task-audio:0.2.0'

2. AndroidMenifest.xml에 다음과 같이 권한 추가

<uses-permission android:name="android.permission.RECORD_AUDIO" />

3. Asset 폴더 추가

4. 모델 파일(.tflite) assets 폴더에 다운로드 (YAMNet) - TensorFlow Hub

https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1

 

TensorFlow Hub

 

tfhub.dev

 

부록 : Kotlin 예제 코드

https://github.com/tensorflow/examples/blob/master/lite/examples/sound_classification/android/app/src/main/java/org/tensorflow/lite/examples/soundclassifier/MainActivity.kt

 

GitHub - tensorflow/examples: TensorFlow examples

TensorFlow examples. Contribute to tensorflow/examples development by creating an account on GitHub.

github.com

 

부록 : Java 예제 코드

https://github.com/tnqkr98/audio_classification_java/blob/master/app/src/main/java/com/tnqkr98/tensoflowlitejava/MainActivity.java

 

GitHub - tnqkr98/audio_classification_java: Audio Classfication (TensorflowLite + Android + Java)

Audio Classfication (TensorflowLite + Android + Java) - GitHub - tnqkr98/audio_classification_java: Audio Classfication (TensorflowLite + Android + Java)

github.com

 

 

MainActivity - onCreate()

 

코드는 Kotlin 예제 코드에서 필요한 부분만 Java로 뽑아본 것이다.

public class MainActivity extends AppCompatActivity {

    private static final String MODEL_FILE = "yamnet.tflite";
    private static final float MINIMUM_DISPLAY_THRESHOLD = 0.3f;

    private AudioClassifier mAudioClassifier;
    private AudioRecord mAudioRecord;
    private long classficationInterval = 500;       // 0.5 sec (샘플링 주기)
    private Handler mHandler;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        HandlerThread handlerThread = new HandlerThread("backgroundThread");
        handlerThread.start();
        mHandler = HandlerCompat.createAsync(handlerThread.getLooper());

        if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.M)
            requestPermissions(new String[]{android.Manifest.permission.RECORD_AUDIO}, 4);

        startAudioClassification();
    }

상수 부터 살펴보면 MODEL_FILE 은 .tflite 파일명을 기입하고, MINIMUM_DISPLAY_THRESHOLD는 원래 예제에서 리사이클러뷰를 이용해, 스코어가 높은 Label 만 출력하는 구조여서 스코어가 0.3 을 넘는 것만 출력하겠다는 뜻이다. (그 이하는 무수히 많기때문에 0.3 정도가 적당한 선인것으로 보인다)

 

AudioClassifier 는 말 그대로 모델의 참조 인스턴스이고, AudioRecord 는 classficationInterval 만큼(밀리초) 씩 녹음한 내용을 담는 객체이다. (변수명은 오타이다)

 

HandlerThread handlerThread = new HandlerThread("backgroundThread");
handlerThread.start();
mHandler = HandlerCompat.createAsync(handlerThread.getLooper());

 

Handler가 Runnable 메시지를 무한히 처리하기 위해서는 Looper가 필요한데 HandlerThead의 루퍼를 이용하도록 하는 구문이다.

 

이후 startAudioClassification() 을 바로 호출하여, 소리 분류를 시작하는데 메서드의 구현부를 살펴보겠다.

 

 

 

MainActivity - startAudioClassification()

private void startAudioClassification(){
        if(mAudioClassifier != null) return;

        try {
            AudioClassifier classifier = AudioClassifier.createFromFile(this, MODEL_FILE);
            TensorAudio audioTensor = classifier.createInputTensorAudio();

            AudioRecord record = classifier.createAudioRecord();
            record.startRecording();

            Runnable run = new Runnable() {
                @Override
                public void run() {
                    audioTensor.load(record);
                    List<Classifications> output = classifier.classify(audioTensor);
                    List<Category> filterModelOutput = output.get(0).getCategories();
                    for(Category c : filterModelOutput) {
                        if (c.getScore() > MINIMUM_DISPLAY_THRESHOLD)
                            Log.d("tensorAudio_java", " label : " + c.getLabel() + " score : " + c.getScore());
                    }

                    mHandler.postDelayed(this,classficationInterval);
                }
            };

            mHandler.post(run);
            mAudioClassifier = classifier;
            mAudioRecord = record;
        }catch (IOException e){
            e.printStackTrace();
        }
    }

 

AudioClassifier classifier = AudioClassifier.createFromFile(this, MODEL_FILE);
TensorAudio audioTensor = classifier.createInputTensorAudio();

 

여기서 오디오 분류기의 인스턴스를 가져오고, 그 분류기를 기반으로 인풋 오디오 텐서를 구성한다.

 

이 후 Runnable 객체를 구성한다. 객체의 run() 메서드를 살펴보면

녹음 내용을 기반으로 텐서를 구성하고

 

List<Classifications> output = classifier.classify(audioTensor);

 

텐서를 분류기에 넣고 소리를 분류한다. Output은 코드와 같이 구성하여 for문으로 확인해볼 수 있다.

 

mHandler.postDelayed(this,classficationInterval);

0.5초간 계속 메시지를 처리하도록 스레드를 구성한다.

 

 

결과 Log

 

 

 

 

코틀린 예제 

/*
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tensorflow.lite.examples.soundclassifier

import android.Manifest
import android.content.pm.PackageManager
import android.media.AudioRecord
import android.os.Build
import android.os.Bundle
import android.os.Handler
import android.os.HandlerThread
import android.util.Log
import android.view.WindowManager
import androidx.annotation.RequiresApi
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat
import androidx.core.os.HandlerCompat
import org.tensorflow.lite.examples.soundclassifier.databinding.ActivityMainBinding
import org.tensorflow.lite.task.audio.classifier.AudioClassifier


class MainActivity : AppCompatActivity() {
  private val probabilitiesAdapter by lazy { ProbabilitiesAdapter() }

  private var audioClassifier: AudioClassifier? = null
  private var audioRecord: AudioRecord? = null
  private var classificationInterval = 500L // how often should classification run in milli-secs
  private lateinit var handler: Handler // background thread handler to run classification

  override fun onCreate(savedInstanceState: Bundle?) {
    super.onCreate(savedInstanceState)

    val binding = ActivityMainBinding.inflate(layoutInflater)
    setContentView(binding.root)

    with(binding) {
      recyclerView.apply {
        setHasFixedSize(false)
        adapter = probabilitiesAdapter
      }

      // Input switch to turn on/off classification
      keepScreenOn(inputSwitch.isChecked)
      inputSwitch.setOnCheckedChangeListener { _, isChecked ->
        if (isChecked) startAudioClassification() else stopAudioClassification()
        keepScreenOn(isChecked)
      }

      // Slider which control how often the classification task should run
      classificationIntervalSlider.value = classificationInterval.toFloat()
      classificationIntervalSlider.setLabelFormatter { value: Float ->
        "${value.toInt()} ms"
      }
      classificationIntervalSlider.addOnChangeListener { _, value, _ ->
        classificationInterval = value.toLong()
        stopAudioClassification()
        startAudioClassification()
      }
    }

    // Create a handler to run classification in a background thread
    val handlerThread = HandlerThread("backgroundThread")
    handlerThread.start()
    handler = HandlerCompat.createAsync(handlerThread.looper)

    // Request microphone permission and start running classification
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
      requestMicrophonePermission()
    } else {
      startAudioClassification()
    }

  }

  private fun startAudioClassification() {
    // If the audio classifier is initialized and running, do nothing.
    if (audioClassifier != null) return;

    // Initialize the audio classifier
    val classifier = AudioClassifier.createFromFile(this, MODEL_FILE)
    val audioTensor = classifier.createInputTensorAudio()

    // Initialize the audio recorder
    val record = classifier.createAudioRecord()
    record.startRecording()

    // Define the classification runnable
    val run = object : Runnable {
      override fun run() {
        val startTime = System.currentTimeMillis()

        // Load the latest audio sample
        audioTensor.load(record)
        val output = classifier.classify(audioTensor)

        // Filter out results above a certain threshold, and sort them descendingly
        val filteredModelOutput = output[0].categories.filter {
          it.score > MINIMUM_DISPLAY_THRESHOLD
        }.sortedBy {
          -it.score
        }

        val finishTime = System.currentTimeMillis()

        Log.d(TAG, "Latency = ${finishTime - startTime}ms")

        // Updating the UI
        runOnUiThread {
          probabilitiesAdapter.categoryList = filteredModelOutput
          probabilitiesAdapter.notifyDataSetChanged()
        }

        // Rerun the classification after a certain interval
        handler.postDelayed(this, classificationInterval)
      }
    }

    // Start the classification process
    handler.post(run)

    // Save the instances we just created for use later
    audioClassifier = classifier
    audioRecord = record
  }

  private fun stopAudioClassification() {
    handler.removeCallbacksAndMessages(null)
    audioRecord?.stop()
    audioRecord = null
    audioClassifier = null
  }

  override fun onTopResumedActivityChanged(isTopResumedActivity: Boolean) {
    // Handles "top" resumed event on multi-window environment
    if (isTopResumedActivity) {
      startAudioClassification()
    } else {
      stopAudioClassification()
    }
  }

  override fun onRequestPermissionsResult(
          requestCode: Int,
          permissions: Array<out String>,
          grantResults: IntArray
  ) {
    if (requestCode == REQUEST_RECORD_AUDIO) {
      if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
        Log.i(TAG, "Audio permission granted :)")
        startAudioClassification()
      } else {
        Log.e(TAG, "Audio permission not granted :(")
      }
    }
  }

  @RequiresApi(Build.VERSION_CODES.M)
  private fun requestMicrophonePermission() {
    if (ContextCompat.checkSelfPermission(
                    this,
                    Manifest.permission.RECORD_AUDIO
            ) == PackageManager.PERMISSION_GRANTED
    ) {
      startAudioClassification()
    } else {
      requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)
    }
  }

  private fun keepScreenOn(enable: Boolean) =
    if (enable) {
      window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
    } else {
      window.clearFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
    }

  companion object {
    const val REQUEST_RECORD_AUDIO = 1337
    private const val TAG = "AudioDemo"
    private const val MODEL_FILE = "yamnet.tflite"
    private const val MINIMUM_DISPLAY_THRESHOLD: Float = 0.3f
  }
}

댓글