[Android Library] PyTorch Mobile - Speech Recognition
안드로이드 앱에서 파이토치 모델을 사용하기위해 파이토치 모바일 라이브러리를 사용해보겠습니다.
소스는 파이토치 모바일 공식 데모 앱으로 Speech Recognition 모델(Wav2Vec)을 예시로 하겠습니다.
작성일 : 21. 07. 09
관련 참조 링크
링크1. 파이토치 Speech Recognition 안드로이드 깃헙 (분석할 프로젝트)
https://github.com/pytorch/android-demo-app/tree/master/SpeechRecognition
링크2. 파이토치 모바일 공식 사이트
https://pytorch.org/mobile/android/
준비물
1. 안드로이드 스튜디오(Android Studio) 버전 4.0.1 이상
2. 파이토치 안드로이드 버전 1.9.0 이상 (build.gradle 의 dependencies 에 추가)
implementation 'org.pytorch:pytorch_android:1.9.0'
3. 모델 파일 다운로드(wav2vec2.pt)
https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view
4. 모델 파일 안드로이드 프로젝트에 추가 (다음과 같이 assets 디렉터리 생성 후, 그 안에 모델 두기)
4. (선택) 파이토치(Pytorch) 1.9.0 & Torchaudio 0.9.0 이상
5. (선택) 파이썬(Python) 3.8 이상
UI 구성
Button, TextView 로 각각 1개. 단조로운 구성. 버튼 누를 시, 12초간 음성 인식 후 텍스트 뷰에 인식한 내용 출력
코드 분석
액티비티 생명주기 순, 프로그램 실행 순으로 주요한 내용만 살펴보겠습니다.
1. MainAvtivity - 멤버 변수
public class MainActivity extends AppCompatActivity implements Runnable {
private static final String TAG = MainActivity.class.getName();
private Module mModuleEncoder;
private TextView mTextView;
private Button mButton;
private final static int REQUEST_RECORD_AUDIO = 13;
private final static int AUDIO_LEN_IN_SECOND = 12;
private final static int SAMPLE_RATE = 16000;
private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;
private final static String LOG_TAG = MainActivity.class.getSimpleName();
private int mStart = 1;
private HandlerThread mTimerThread;
private Handler mTimerHandler;
private Runnable mRunnable = new Runnable() {
@Override
public void run() {
mTimerHandler.postDelayed(mRunnable, 1000);
MainActivity.this.runOnUiThread(
() -> {
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
mStart += 1;
});
}
};
주요 멤버부터 살펴보면,
Module mModuleEncoder : 파이토치 모델을 참조하는 객체
int AUDIO_LEN_IN_SECOND : 몇 초간 녹음할 것인지 지정
int SAMPLE_RATE : 샘플링(Sampling), 초당 몇 개의 디지털 신호를 뽑아낼지 지정
int RECORDING_LENGTH : 모델(Model)에 인풋(Input)으로 들어갈 배열의 길이.
다음은 UI 갱신을 위한 백그라운드 스레드(Background Thread)를 위해 필요한 멤버입니다.
HandlerThread mTimerThread : 핸들러 스레드 객체. 핸들러 스레드는 자체적으로 루퍼를 가지고 있다.
Handler mTimerHandler : TextView를 갱신하는 메시지를 처리할 핸들러
Runnable mRunnable : 매 1초마다 TextView 갱신 메시지(setText()) 송신 정의
- MainActivity.this.runOnUiThread() 는 UI 스레드에서 함수를 실행하도록 지원하는 메서드.
- (기본기) UI 스레드를 제외하고 다른 스레드에서 UI를 갱신할 수 없다.
눈여겨볼점은 메인 액티비티가 Runnable 인터페이스를 구현하고 있다는 점입니다.
2. MainActivity - onCreate()
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
mButton = findViewById(R.id.btnRecognize);
mTextView = findViewById(R.id.tvResult);
mButton.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
mButton.setEnabled(false);
Thread thread = new Thread(MainActivity.this);
thread.start();
mTimerThread = new HandlerThread("Timer");
mTimerThread.start();
mTimerHandler = new Handler(mTimerThread.getLooper());
mTimerHandler.postDelayed(mRunnable, 1000);
}
});
requestMicrophonePermission();
}
여기서 살펴볼것은 버튼을 누를시, 두 가지 스레드가 실행됩니다. 첫번째로 녹음과 음성인지를 담당하는 스레드 실행.
(메인액티비티는 Runnable을 구현하므로 생성자 파라미터로 전달가능)
여기서 오버라이드 한 run() 메서드를 실행하게 됩니다.
Thread thread = new Thread(MainActivity.this);
thread.start();
두번째 스레드로 12초간의 시간 흐름을 TextView로 화면에 보여줄 HandlerThead를 실행합니다.
mTimerThread = new HandlerThread("Timer");
mTimerThread.start();
mTimerHandler = new Handler(mTimerThread.getLooper());
mTimerHandler.postDelayed(mRunnable, 1000);
mTimerHandler에 mHandlerThead의 루퍼(Looper)를 전달하여 메시지를 처리하도록하고 바로 postDelayed 메서드로 1초 지연 실행 메시지를 보냅니다. mRunnable은 또 위에서 멤버 객체 초기화 당시 postDelayed 호출하도록 구현했기때문에, 무한히 메시지를 1초마다 반복해서 보냅니다.(자기 자신을) 이것은 stopTimerThead() 메서드 호출 시 중단됩니다.
3. MainActivity - run()
@Override
public void run() {
android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);
int bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
AudioRecord record = new AudioRecord(MediaRecorder.AudioSource.DEFAULT, SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT,
bufferSize);
if (record.getState() != AudioRecord.STATE_INITIALIZED) {
Log.e(LOG_TAG, "Audio Record can't initialize!");
return;
}
record.startRecording();
long shortsRead = 0;
int recordingOffset = 0;
short[] audioBuffer = new short[bufferSize / 2];
short[] recordingBuffer = new short[RECORDING_LENGTH];
while (shortsRead < RECORDING_LENGTH) {
int numberOfShort = record.read(audioBuffer, 0, audioBuffer.length);
shortsRead += numberOfShort;
System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, numberOfShort);
recordingOffset += numberOfShort;
}
record.stop();
record.release();
stopTimerThread();
runOnUiThread(new Runnable() {
@Override
public void run() {
mButton.setText("Recognizing...");
}
});
float[] floatInputBuffer = new float[RECORDING_LENGTH];
// feed in float values between -1.0f and 1.0f by dividing the signed 16-bit inputs.
for (int i = 0; i < RECORDING_LENGTH; ++i) {
floatInputBuffer[i] = recordingBuffer[i] / (float)Short.MAX_VALUE;
}
final String result = recognize(floatInputBuffer);
runOnUiThread(new Runnable() {
@Override
public void run() {
showTranslationResult(result);
mButton.setEnabled(true);
mButton.setText("Start");
}
});
}
이 메서드는 Runnable 인터페이스를 구현하는 오버라이드 메서드로 버튼을 눌렀을때 백그라운드 스레드를 통해 실행되며, 내부에 while문으로 12초간 수행하도록 정의되어있습니다.
int bufferSize : 녹음한 소리에서 한번에 몇 개의 디지털 데이터를 버퍼에 담아올지 결정합니다. 여기서 AudioRecord.getMinBufferSize를 이용하는데, 이 메서드가 권장되는듯 합니다. (ToDo - 문서 확인 필요)
while문에서 위에서 정한 버퍼 사이즈만큼씩 가져와 short[] recordingBuffer 에 추가합니다. 이 배열의 길이는
RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;
으로 위에서 정의해 놓았습니다. 이 배열을 12초간 채우는 것입니다.
* 참고로 녹음을 스테레오로 하려면, AudioFormat.CHANNEL_IN_MONO 를 AudioFormat.CHANNEL_IN_STEREO 로 변경하고 RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND *2 로 해야 올바르게 측정됩니다.
* 스테레오의 경우 인풋 텐서(Tensor)를 결정할때 데이터가 recordingBuffer에 다음과 같이 저장되어있음에 주의합니다.
MONO -> [ left, left, left, left, left ..... ]
STEREO -> [ left, right, left, right, left ..... ]
이렇게 추출한 12초짜리 short형 recoringBuffer 배열은 모델에 들어가기 적합하도록 표준화(Standardization)를 수행합니다. (표준화는 데이터를 -1. ~ 1. 사이의 수로 변환) 표준화한 배열은 floatInputBuffer 입니다.
이것을 recognize의 파라미터로 넣어줍니다. recognize() 메서드는 녹음한 내용을 모델이 넣고 output을 반환하는 메서드입니다.
4. MainAcitivity - recognize()
private String recognize(float[] floatInputBuffer) {
if (mModuleEncoder == null) {
final String moduleFileAbsoluteFilePath = new File(
assetFilePath(this, "wav2vec2.pt")).getAbsolutePath();
mModuleEncoder = Module.load(moduleFileAbsoluteFilePath);
}
double wav2vecinput[] = new double[RECORDING_LENGTH];
for (int n = 0; n < RECORDING_LENGTH; n++)
wav2vecinput[n] = floatInputBuffer[n];
FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
for (double val : wav2vecinput)
inTensorBuffer.put((float)val);
Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
final String result = mModuleEncoder.forward(IValue.from(inTensor)).toStr();
return result;
}
recognize() 메서드는 assets 폴더에서 모델 파일(wav2vec2.pt) 를 불러 온뒤, mModuleDencoder에 모듈을 로드합니다.
몇 가지 형변환 처리를 한 후, inpup 텐서를 정의합니다. 여기서 눈여겨 볼 메서드는 Tensor.fromBlob 입니다.
이 메서드는 주어진 버퍼를 두 번째 파라미터( new long[]{1, RECORDING_LENGTH} ) 의 Shape로 텐서를 구성하는 메서드 입니다. ( 1 x RECORDING_LENGTH )
여기서 이 모델 파일은 파이토치 모바일용으로 경량화 된 것입니다. 아래와 같이 모바일용으로 변환해야합니다.
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
이렇게 input 텐서가 구성되면 mduleEncoder.forward()에 전달합니다. 이 메서드의 반환값은 IValue 객체인데 이를 텐서로 변환한 뒤 아래와 같이 배열로 바꿀 수 있고, 또는 위와 같이 toStr() 메서드를 통해 바로 문자열로 변환가능합니다.
(다만 이는 모델의 아웃풋 텐서 형태가 무엇인지 명확히 알고 있어야 올바르게 구현할 수 있습니다.)
iValue.toTensor().getDataAsFloatArray();
5. MainActivity 에 정의된 보조 메서드
private String assetFilePath(Context context, String assetName)
모델 파일을 불러오는 메서드입니다.
private void requestMicrophonePermission()
녹음 권한을 요청하는 메서드입니다.
@Override
protected void onDestroy()
액티비티가 종료될때 타이머 스레드 역시 메모리해제를 필요로합니다. 여기서 stopTimerThead()를 호출합니다.
protected void stopTimerThread()
각종 메모리 해제를 담당하는 메서드입니다.
끝.
'Mobile : Android > Library' 카테고리의 다른 글
[Android Library] Retrofit2 #2 - HTTP Method (0) | 2021.09.17 |
---|---|
[Android Library] Retrofit2 #1 - 레트로핏 기본 개괄 (0) | 2021.09.08 |
[Android Library] Tensorflow Lite - Sound Classification (0) | 2021.08.25 |
댓글