본문 바로가기
Mobile : Android/Library

[Android Library] PyTorch Mobile - Speech Recognition

by 신숭이 2021. 7. 9.

[Android Library] PyTorch Mobile - Speech Recognition

 

 

 

안드로이드 앱에서 파이토치 모델을 사용하기위해 파이토치 모바일 라이브러리를 사용해보겠습니다.

소스는 파이토치 모바일 공식 데모 앱으로 Speech Recognition 모델(Wav2Vec)을 예시로 하겠습니다.

작성일 : 21. 07. 09

 

관련 참조 링크


링크1. 파이토치 Speech Recognition 안드로이드 깃헙 (분석할 프로젝트)

https://github.com/pytorch/android-demo-app/tree/master/SpeechRecognition

 

pytorch/android-demo-app

PyTorch android examples of usage in applications. Contribute to pytorch/android-demo-app development by creating an account on GitHub.

github.com

 

링크2. 파이토치 모바일 공식 사이트

https://pytorch.org/mobile/android/

 

PyTorch

An open source machine learning framework that accelerates the path from research prototyping to production deployment.

pytorch.org

 

 

준비물


1. 안드로이드 스튜디오(Android Studio) 버전 4.0.1 이상

2. 파이토치 안드로이드 버전 1.9.0 이상 (build.gradle 의 dependencies 에 추가) 

implementation 'org.pytorch:pytorch_android:1.9.0'

3. 모델 파일 다운로드(wav2vec2.pt

https://drive.google.com/file/d/1RcCy3K3gDVN2Nun5IIdDbpIDbrKD-XVw/view

 

wav2vec2.pt

 

drive.google.com

4. 모델 파일 안드로이드 프로젝트에 추가 (다음과 같이 assets 디렉터리 생성 후, 그 안에 모델 두기)

4. (선택) 파이토치(Pytorch) 1.9.0 & Torchaudio 0.9.0 이상

5. (선택) 파이썬(Python) 3.8 이상

 

 

UI 구성


Button, TextView 로 각각 1개. 단조로운 구성. 버튼 누를 시, 12초간 음성 인식 후 텍스트 뷰에 인식한 내용 출력

 

코드 분석


액티비티 생명주기 순, 프로그램 실행 순으로 주요한 내용만 살펴보겠습니다. 

 

1. MainAvtivity - 멤버 변수

public class MainActivity extends AppCompatActivity implements Runnable {
    private static final String TAG = MainActivity.class.getName();

    private Module mModuleEncoder;
    private TextView mTextView;
    private Button mButton;

    private final static int REQUEST_RECORD_AUDIO = 13;
    private final static int AUDIO_LEN_IN_SECOND = 12;
    private final static int SAMPLE_RATE = 16000;
    private final static int RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;

    private final static String LOG_TAG = MainActivity.class.getSimpleName();

    private int mStart = 1;
    private HandlerThread mTimerThread;
    private Handler mTimerHandler;
    private Runnable mRunnable = new Runnable() {
        @Override
        public void run() {
            mTimerHandler.postDelayed(mRunnable, 1000);

            MainActivity.this.runOnUiThread(
                    () -> {
                        mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND - mStart));
                        mStart += 1;
                    });
        }
    };

주요 멤버부터 살펴보면,

Module mModuleEncoder : 파이토치 모델을 참조하는 객체 

int AUDIO_LEN_IN_SECOND : 몇 초간 녹음할 것인지 지정

int SAMPLE_RATE : 샘플링(Sampling), 초당 몇 개의 디지털 신호를 뽑아낼지 지정

int RECORDING_LENGTH : 모델(Model)에 인풋(Input)으로 들어갈 배열의 길이. 

 

다음은 UI 갱신을 위한 백그라운드 스레드(Background Thread)를 위해 필요한 멤버입니다.

HandlerThread mTimerThread : 핸들러 스레드 객체. 핸들러 스레드는 자체적으로 루퍼를 가지고 있다.

Handler mTimerHandler : TextView를 갱신하는 메시지를 처리할 핸들러

Runnable mRunnable : 매 1초마다 TextView 갱신 메시지(setText()) 송신 정의

 - MainActivity.this.runOnUiThread() 는 UI 스레드에서 함수를 실행하도록 지원하는 메서드.

 - (기본기) UI 스레드를 제외하고 다른 스레드에서 UI를 갱신할 수 없다.

 

눈여겨볼점은 메인 액티비티가 Runnable 인터페이스를 구현하고 있다는 점입니다.

 

2. MainActivity - onCreate() 

@Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        mButton = findViewById(R.id.btnRecognize);
        mTextView = findViewById(R.id.tvResult);

        mButton.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mButton.setText(String.format("Listening - %ds left", AUDIO_LEN_IN_SECOND));
                mButton.setEnabled(false);

                Thread thread = new Thread(MainActivity.this);
                thread.start();

                mTimerThread = new HandlerThread("Timer");
                mTimerThread.start();
                mTimerHandler = new Handler(mTimerThread.getLooper());
                mTimerHandler.postDelayed(mRunnable, 1000);

            }
        });
        requestMicrophonePermission();
    }

 

여기서 살펴볼것은 버튼을 누를시, 두 가지 스레드가 실행됩니다. 첫번째로 녹음과 음성인지를 담당하는 스레드 실행.

(메인액티비티는 Runnable을 구현하므로 생성자 파라미터로 전달가능)

여기서 오버라이드 한 run() 메서드를 실행하게 됩니다.

Thread thread = new Thread(MainActivity.this);
thread.start();

 

두번째 스레드로 12초간의 시간 흐름을 TextView로 화면에 보여줄 HandlerThead를 실행합니다.

mTimerThread = new HandlerThread("Timer");
mTimerThread.start();
mTimerHandler = new Handler(mTimerThread.getLooper());
mTimerHandler.postDelayed(mRunnable, 1000);

 

mTimerHandlermHandlerThead의 루퍼(Looper)를 전달하여 메시지를 처리하도록하고 바로 postDelayed 메서드로 1초 지연 실행 메시지를 보냅니다. mRunnable은 또 위에서 멤버 객체 초기화 당시 postDelayed 호출하도록 구현했기때문에, 무한히 메시지를 1초마다 반복해서 보냅니다.(자기 자신을) 이것은 stopTimerThead() 메서드 호출 시 중단됩니다.

 

 

3. MainActivity - run()

@Override
public void run() {
        android.os.Process.setThreadPriority(android.os.Process.THREAD_PRIORITY_AUDIO);

        int bufferSize = AudioRecord.getMinBufferSize(SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT);
        AudioRecord record = new AudioRecord(MediaRecorder.AudioSource.DEFAULT, SAMPLE_RATE, AudioFormat.CHANNEL_IN_MONO, AudioFormat.ENCODING_PCM_16BIT,
                bufferSize);

        if (record.getState() != AudioRecord.STATE_INITIALIZED) {
            Log.e(LOG_TAG, "Audio Record can't initialize!");
            return;
        }
        record.startRecording();

        long shortsRead = 0;
        int recordingOffset = 0;
        short[] audioBuffer = new short[bufferSize / 2];
        short[] recordingBuffer = new short[RECORDING_LENGTH];

        while (shortsRead < RECORDING_LENGTH) {
            int numberOfShort = record.read(audioBuffer, 0, audioBuffer.length);
            shortsRead += numberOfShort;
            System.arraycopy(audioBuffer, 0, recordingBuffer, recordingOffset, numberOfShort);
            recordingOffset += numberOfShort;
        }

        record.stop();
        record.release();
        stopTimerThread();

        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                mButton.setText("Recognizing...");
            }
        });

        float[] floatInputBuffer = new float[RECORDING_LENGTH];

        // feed in float values between -1.0f and 1.0f by dividing the signed 16-bit inputs.
        for (int i = 0; i < RECORDING_LENGTH; ++i) {
            floatInputBuffer[i] = recordingBuffer[i] / (float)Short.MAX_VALUE;
        }

        final String result = recognize(floatInputBuffer);

        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                showTranslationResult(result);
                mButton.setEnabled(true);
                mButton.setText("Start");
            }
        });
    }

이 메서드는 Runnable 인터페이스를 구현하는 오버라이드 메서드로 버튼을 눌렀을때 백그라운드 스레드를 통해 실행되며, 내부에 while문으로 12초간 수행하도록 정의되어있습니다.

 

int bufferSize : 녹음한 소리에서 한번에 몇 개의 디지털 데이터를 버퍼에 담아올지 결정합니다. 여기서 AudioRecord.getMinBufferSize를 이용하는데, 이 메서드가 권장되는듯 합니다. (ToDo - 문서 확인 필요)

 

while문에서 위에서 정한 버퍼 사이즈만큼씩 가져와 short[] recordingBuffer 에 추가합니다. 이 배열의 길이는

RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND;

으로 위에서 정의해 놓았습니다. 이 배열을 12초간 채우는 것입니다.

 

* 참고로 녹음을 스테레오로 하려면, AudioFormat.CHANNEL_IN_MONOAudioFormat.CHANNEL_IN_STEREO 로 변경하고 RECORDING_LENGTH = SAMPLE_RATE * AUDIO_LEN_IN_SECOND *2 로 해야 올바르게 측정됩니다.

 

* 스테레오의 경우 인풋 텐서(Tensor)를 결정할때  데이터가 recordingBuffer에 다음과 같이 저장되어있음에 주의합니다.

 MONO  ->  [ left, left, left, left, left ..... ]

 STEREO ->  [ left, right, left, right, left ..... ]

 

이렇게 추출한 12초짜리 short형 recoringBuffer 배열은 모델에 들어가기 적합하도록 표준화(Standardization)를 수행합니다. (표준화는 데이터를 -1. ~ 1.  사이의 수로 변환) 표준화한 배열은 floatInputBuffer 입니다.

 

이것을 recognize의 파라미터로 넣어줍니다. recognize() 메서드는 녹음한 내용을 모델이 넣고 output을 반환하는 메서드입니다.

 

4. MainAcitivity - recognize()

private String recognize(float[] floatInputBuffer) {
        if (mModuleEncoder == null) {
            final String moduleFileAbsoluteFilePath = new File(
                    assetFilePath(this, "wav2vec2.pt")).getAbsolutePath();
            mModuleEncoder = Module.load(moduleFileAbsoluteFilePath);
        }

        double wav2vecinput[] = new double[RECORDING_LENGTH];
        for (int n = 0; n < RECORDING_LENGTH; n++)
            wav2vecinput[n] = floatInputBuffer[n];

        FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(RECORDING_LENGTH);
        for (double val : wav2vecinput)
            inTensorBuffer.put((float)val);

        Tensor inTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, RECORDING_LENGTH});
        final String result = mModuleEncoder.forward(IValue.from(inTensor)).toStr();

        return result;
    }

recognize() 메서드는 assets 폴더에서 모델 파일(wav2vec2.pt) 를 불러 온뒤, mModuleDencoder에 모듈을 로드합니다.

몇 가지 형변환 처리를 한 후, inpup 텐서를 정의합니다. 여기서 눈여겨 볼 메서드는 Tensor.fromBlob 입니다.

이 메서드는 주어진 버퍼를 두 번째 파라미터( new long[]{1, RECORDING_LENGTH} ) 의 Shape로 텐서를 구성하는 메서드 입니다. ( 1 x RECORDING_LENGTH )

 

여기서 이 모델 파일은 파이토치 모바일용으로 경량화 된 것입니다. 아래와 같이 모바일용으로 변환해야합니다.

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")

 

이렇게 input 텐서가 구성되면 mduleEncoder.forward()에 전달합니다. 이 메서드의 반환값은 IValue 객체인데 이를 텐서로 변환한 뒤 아래와 같이 배열로 바꿀 수 있고, 또는 위와 같이 toStr() 메서드를 통해 바로 문자열로 변환가능합니다.

(다만 이는 모델의 아웃풋 텐서 형태가 무엇인지 명확히 알고 있어야 올바르게 구현할 수 있습니다.)

iValue.toTensor().getDataAsFloatArray();

 

5. MainActivity 에 정의된 보조 메서드

private String assetFilePath(Context context, String assetName)

모델 파일을 불러오는 메서드입니다.

 

private void requestMicrophonePermission()

녹음 권한을 요청하는 메서드입니다.

 

@Override
protected void onDestroy()

액티비티가 종료될때 타이머 스레드 역시 메모리해제를 필요로합니다. 여기서 stopTimerThead()를 호출합니다.

 

protected void stopTimerThread()

각종 메모리 해제를 담당하는 메서드입니다.

 

 

 

 

끝.

 

댓글