[Android Library] Tensorflow Lite - Sound Classification
경량화 기술의 발달로 모바일에서 AI 모델을 돌릴 수 있는 기술들이 많이 등장하고 있다.
지난번에 PyTorch Mobile 을 이용한 음성 인식 모델 사용법에 이어, 좀 더 간편하게 사용할 수 있는 Tensorflow Lite 라이브러리를 소개하고자 한다.예시는 YAMNet 모델 이용하여 521가지의 소리를 분류해 보겠다.
언어는 Java를 이용하겠다. Kotlin 은 구글에서 제공하는 공식 레퍼런스를 참조하면 된다.
(프로젝트가 Java 여서 테스트 코드를 만들다 보니, 부득이하게 Java를 예시로 들겠다)
UI 는 없이 Log로 찍어서 확인하겠다.
준비
1. Gradle 앱수준의 dependencies
implementation 'org.tensorflow:tensorflow-lite-task-audio:0.2.0'
2. AndroidMenifest.xml에 다음과 같이 권한 추가
<uses-permission android:name="android.permission.RECORD_AUDIO" />
3. Asset 폴더 추가
4. 모델 파일(.tflite) assets 폴더에 다운로드 (YAMNet) - TensorFlow Hub
https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1
부록 : Kotlin 예제 코드
부록 : Java 예제 코드
MainActivity - onCreate()
코드는 Kotlin 예제 코드에서 필요한 부분만 Java로 뽑아본 것이다.
public class MainActivity extends AppCompatActivity {
private static final String MODEL_FILE = "yamnet.tflite";
private static final float MINIMUM_DISPLAY_THRESHOLD = 0.3f;
private AudioClassifier mAudioClassifier;
private AudioRecord mAudioRecord;
private long classficationInterval = 500; // 0.5 sec (샘플링 주기)
private Handler mHandler;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
HandlerThread handlerThread = new HandlerThread("backgroundThread");
handlerThread.start();
mHandler = HandlerCompat.createAsync(handlerThread.getLooper());
if(Build.VERSION.SDK_INT >= Build.VERSION_CODES.M)
requestPermissions(new String[]{android.Manifest.permission.RECORD_AUDIO}, 4);
startAudioClassification();
}
상수 부터 살펴보면 MODEL_FILE 은 .tflite 파일명을 기입하고, MINIMUM_DISPLAY_THRESHOLD는 원래 예제에서 리사이클러뷰를 이용해, 스코어가 높은 Label 만 출력하는 구조여서 스코어가 0.3 을 넘는 것만 출력하겠다는 뜻이다. (그 이하는 무수히 많기때문에 0.3 정도가 적당한 선인것으로 보인다)
AudioClassifier 는 말 그대로 모델의 참조 인스턴스이고, AudioRecord 는 classficationInterval 만큼(밀리초) 씩 녹음한 내용을 담는 객체이다. (변수명은 오타이다)
HandlerThread handlerThread = new HandlerThread("backgroundThread");
handlerThread.start();
mHandler = HandlerCompat.createAsync(handlerThread.getLooper());
Handler가 Runnable 메시지를 무한히 처리하기 위해서는 Looper가 필요한데 HandlerThead의 루퍼를 이용하도록 하는 구문이다.
이후 startAudioClassification() 을 바로 호출하여, 소리 분류를 시작하는데 메서드의 구현부를 살펴보겠다.
MainActivity - startAudioClassification()
private void startAudioClassification(){
if(mAudioClassifier != null) return;
try {
AudioClassifier classifier = AudioClassifier.createFromFile(this, MODEL_FILE);
TensorAudio audioTensor = classifier.createInputTensorAudio();
AudioRecord record = classifier.createAudioRecord();
record.startRecording();
Runnable run = new Runnable() {
@Override
public void run() {
audioTensor.load(record);
List<Classifications> output = classifier.classify(audioTensor);
List<Category> filterModelOutput = output.get(0).getCategories();
for(Category c : filterModelOutput) {
if (c.getScore() > MINIMUM_DISPLAY_THRESHOLD)
Log.d("tensorAudio_java", " label : " + c.getLabel() + " score : " + c.getScore());
}
mHandler.postDelayed(this,classficationInterval);
}
};
mHandler.post(run);
mAudioClassifier = classifier;
mAudioRecord = record;
}catch (IOException e){
e.printStackTrace();
}
}
AudioClassifier classifier = AudioClassifier.createFromFile(this, MODEL_FILE);
TensorAudio audioTensor = classifier.createInputTensorAudio();
여기서 오디오 분류기의 인스턴스를 가져오고, 그 분류기를 기반으로 인풋 오디오 텐서를 구성한다.
이 후 Runnable 객체를 구성한다. 객체의 run() 메서드를 살펴보면
녹음 내용을 기반으로 텐서를 구성하고
List<Classifications> output = classifier.classify(audioTensor);
텐서를 분류기에 넣고 소리를 분류한다. Output은 코드와 같이 구성하여 for문으로 확인해볼 수 있다.
mHandler.postDelayed(this,classficationInterval);
0.5초간 계속 메시지를 처리하도록 스레드를 구성한다.
결과 Log
코틀린 예제
/*
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.lite.examples.soundclassifier
import android.Manifest
import android.content.pm.PackageManager
import android.media.AudioRecord
import android.os.Build
import android.os.Bundle
import android.os.Handler
import android.os.HandlerThread
import android.util.Log
import android.view.WindowManager
import androidx.annotation.RequiresApi
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat
import androidx.core.os.HandlerCompat
import org.tensorflow.lite.examples.soundclassifier.databinding.ActivityMainBinding
import org.tensorflow.lite.task.audio.classifier.AudioClassifier
class MainActivity : AppCompatActivity() {
private val probabilitiesAdapter by lazy { ProbabilitiesAdapter() }
private var audioClassifier: AudioClassifier? = null
private var audioRecord: AudioRecord? = null
private var classificationInterval = 500L // how often should classification run in milli-secs
private lateinit var handler: Handler // background thread handler to run classification
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
val binding = ActivityMainBinding.inflate(layoutInflater)
setContentView(binding.root)
with(binding) {
recyclerView.apply {
setHasFixedSize(false)
adapter = probabilitiesAdapter
}
// Input switch to turn on/off classification
keepScreenOn(inputSwitch.isChecked)
inputSwitch.setOnCheckedChangeListener { _, isChecked ->
if (isChecked) startAudioClassification() else stopAudioClassification()
keepScreenOn(isChecked)
}
// Slider which control how often the classification task should run
classificationIntervalSlider.value = classificationInterval.toFloat()
classificationIntervalSlider.setLabelFormatter { value: Float ->
"${value.toInt()} ms"
}
classificationIntervalSlider.addOnChangeListener { _, value, _ ->
classificationInterval = value.toLong()
stopAudioClassification()
startAudioClassification()
}
}
// Create a handler to run classification in a background thread
val handlerThread = HandlerThread("backgroundThread")
handlerThread.start()
handler = HandlerCompat.createAsync(handlerThread.looper)
// Request microphone permission and start running classification
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestMicrophonePermission()
} else {
startAudioClassification()
}
}
private fun startAudioClassification() {
// If the audio classifier is initialized and running, do nothing.
if (audioClassifier != null) return;
// Initialize the audio classifier
val classifier = AudioClassifier.createFromFile(this, MODEL_FILE)
val audioTensor = classifier.createInputTensorAudio()
// Initialize the audio recorder
val record = classifier.createAudioRecord()
record.startRecording()
// Define the classification runnable
val run = object : Runnable {
override fun run() {
val startTime = System.currentTimeMillis()
// Load the latest audio sample
audioTensor.load(record)
val output = classifier.classify(audioTensor)
// Filter out results above a certain threshold, and sort them descendingly
val filteredModelOutput = output[0].categories.filter {
it.score > MINIMUM_DISPLAY_THRESHOLD
}.sortedBy {
-it.score
}
val finishTime = System.currentTimeMillis()
Log.d(TAG, "Latency = ${finishTime - startTime}ms")
// Updating the UI
runOnUiThread {
probabilitiesAdapter.categoryList = filteredModelOutput
probabilitiesAdapter.notifyDataSetChanged()
}
// Rerun the classification after a certain interval
handler.postDelayed(this, classificationInterval)
}
}
// Start the classification process
handler.post(run)
// Save the instances we just created for use later
audioClassifier = classifier
audioRecord = record
}
private fun stopAudioClassification() {
handler.removeCallbacksAndMessages(null)
audioRecord?.stop()
audioRecord = null
audioClassifier = null
}
override fun onTopResumedActivityChanged(isTopResumedActivity: Boolean) {
// Handles "top" resumed event on multi-window environment
if (isTopResumedActivity) {
startAudioClassification()
} else {
stopAudioClassification()
}
}
override fun onRequestPermissionsResult(
requestCode: Int,
permissions: Array<out String>,
grantResults: IntArray
) {
if (requestCode == REQUEST_RECORD_AUDIO) {
if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
Log.i(TAG, "Audio permission granted :)")
startAudioClassification()
} else {
Log.e(TAG, "Audio permission not granted :(")
}
}
}
@RequiresApi(Build.VERSION_CODES.M)
private fun requestMicrophonePermission() {
if (ContextCompat.checkSelfPermission(
this,
Manifest.permission.RECORD_AUDIO
) == PackageManager.PERMISSION_GRANTED
) {
startAudioClassification()
} else {
requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)
}
}
private fun keepScreenOn(enable: Boolean) =
if (enable) {
window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
} else {
window.clearFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON)
}
companion object {
const val REQUEST_RECORD_AUDIO = 1337
private const val TAG = "AudioDemo"
private const val MODEL_FILE = "yamnet.tflite"
private const val MINIMUM_DISPLAY_THRESHOLD: Float = 0.3f
}
}
'Mobile : Android > Library' 카테고리의 다른 글
[Android Library] Retrofit2 #2 - HTTP Method (0) | 2021.09.17 |
---|---|
[Android Library] Retrofit2 #1 - 레트로핏 기본 개괄 (0) | 2021.09.08 |
[Android Library] PyTorch Mobile - Speech Recognition (0) | 2021.07.09 |
댓글