@senspond

>

개발>모바일

딥러닝 모델을 플러터로 모바일기기에 임베딩하기

등록일시 : 2024-11-14 (목) 12:12
업데이트 : 2025-01-22 (수) 11:06
오늘 조회수 : 4
총 조회수 : 611

    딥러닝 모델을 플러터로 모바일기기에 임베딩하기

    오랜만에 블로그에 글을 쓴다.

    MNIST 데이터셋으로 학습한 손글씨 분류 모델을 Flutter로 임베딩 해봤다.







    준비

    TensorFlow Lite (TFLite)로 변환을 해야 모바일 기기에 임베딩을 할 수가 있다. tflite는 TensorFlow경량 버전으로, 모바일 및 임베디드 장치에서 딥러닝 모델을 추론 할 수 있도록 설계된 프레임워크이다. TensorFlow Lite는 경량성, 효율성, 빠른 추론을 목표로 하며, 스마트폰, IoT 장치, 마이크로 컨트롤러와 같은 리소스가 제한된 장치에서도 머신러닝 추론을 수행할 수 있게 한다. 양자화 기법 등을 통해 모델 사이즈를 줄일 수 있고, 오직 추론만 가능하도록 설계되어 있다.


    파이토치 → tflite

    먼저 ONNX 로 변환한다음, 텐서플로우 모델로 변환하고 최종적으로 텐서플로우 라이트(tflite)로 변환해야 한다. 변환과정이 다소 번거롭고 복잡하기에 변환하는 과정에서 네트워크 정보가 손실되는 경우들이 생긴다. 파이토치, onnx, 텐서플로우, 텐서플로우 라이트 모두 고유한 연산자를 사용하는데, 일부 특정 프레임워크에서 사용되는 비표준 연산자 같은 경우 제대로 맵핑이 되지 않을 수가 있다. 딥러닝 연구를 할 때는 파이토치를 많이 사용하지만, tflite로 모바일 기기 임베딩을 할 때는 텐서플로우로 개발을 하는 것이 정신건강에 좋을 것 같다.


    파이토치 → onnx

    import torch
    import torch.onnx
    
    # PyTorch 모델 정의 (예시 모델)
    class SimpleModel(torch.nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc = torch.nn.Linear(10, 5)
    
        def forward(self, x):
            return self.fc(x)
    
    # 모델 인스턴스 생성 및 저장
    model = SimpleModel()
    model.eval()
    
    # 더미 입력 정의 (모델의 입력 형식과 일치해야 함)
    dummy_input = torch.randn(1, 10)
    
    # 모델을 ONNX 형식으로 저장
    onnx_model_path = "simple_model.onnx"
    torch.onnx.export(
        model, dummy_input, onnx_model_path,
        input_names=['input'], output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
    )
    
    print(f"ONNX 모델이 저장되었습니다: {onnx_model_path}")
    


    onnx → tensorflow

    import onnx
    from onnx_tf.backend import prepare
    
    # ONNX 모델 로드
    onnx_model = onnx.load(onnx_model_path)
    
    # TensorFlow 모델로 변환
    tf_rep = prepare(onnx_model)
    
    # 변환된 TensorFlow 모델 저장
    tf_model_path = "simple_model_tf"
    tf_rep.export_graph(tf_model_path)
    print(f"TensorFlow 모델이 저장되었습니다: {tf_model_path}")
    


    tensorflow → tflite

    import tensorflow as tf
    
    # TensorFlow 모델 로드
    saved_model_dir = tf_model_path
    
    # TFLite Converter 사용
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    
    # 모델 최적화 옵션 (선택사항)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # TFLite 모델로 변환
    tflite_model = converter.convert()
    
    # TFLite 모델 저장
    tflite_model_path = "simple_model.tflite"
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)
    
    print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")
    


    텐서플로우 → tflite

    import tensorflow as tf
    
    # TensorFlow 모델 로드
    saved_model_dir = tf_model_path
    
    # TFLite Converter 사용
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    
    # 모델 최적화 옵션 (선택사항)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # TFLite 모델로 변환
    tflite_model = converter.convert()
    
    # TFLite 모델 저장
    tflite_model_path = "simple_model.tflite"
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)
    
    print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")


    케라스→ tflite

    .h5 혹은 .keras 형태로 저장된 케라스 모델을 텐서플로우 모델로 변환 후 텐서플로우 라이트(tflite) 로 변환해야 한다.

    import tensorflow as tf
    
    # Keras 모델 (.h5 /.keras 파일) 로드
    h5_model_path = "my_model.h5"  # Keras 모델이 저장된 경로
    keras_model = tf.keras.models.load_model(h5_model_path)
    
    # TFLiteConverter를 사용하여 Keras 모델을 TFLite 모델로 변환
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    
    # 모델 최적화 옵션 설정 (선택사항)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # TFLite 모델 변환 수행
    tflite_model = converter.convert()
    
    # TFLite 모델을 파일로 저장
    tflite_model_path = "my_model.tflite"
    with open(tflite_model_path, "wb") as f:
        f.write(tflite_model)
    
    print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")
    



    변환된 TFLite 모델의 입출력 확인

    # 변환된 TFLite 모델의 입출력 확인
    interpreter = tf.lite.Interpreter(model_path="style_transfer_model.tflite")
    interpreter.allocate_tensors()
    
    # 입력 텐서 정보 출력
    input_details = interpreter.get_input_details()
    print("Input Tensor Details:")
    for detail in input_details:
        print(detail)
    
    # 출력 텐서 정보 출력
    output_details = interpreter.get_output_details()
    print("Output Tensor Details:")
    for detail in output_details:
        print(detail)


    플러터 세팅

    플러터에서 텐서플로우 라이트 모델을 사용하기 위한 패키지로는 tflite 와 tflite_flutter 가 있다.


    GitHub - shaqian/flutter_tflite: Flutter plugin for TensorFlow Lite

    위 tflite는 마지막으로 업데이트 된지 3년이 넘었고, 현재 유지보수를 안하고 있고 제대로 적용이 되지 않았다.

    현재 플러터 SDK (sdk: '>=3.4.4 <4.0.0') 에서 필자가 정상적으로 테스트 성공한 것은 tflite_fluttertflite_flutter_helper 이다.


    GitHub - tensorflow/flutter-tflite

    위 tflite_flutter는 작년까지 업데이트 기록이 있는데, 최근에 유지보수를 안하는 것 같다.


    GitHub - am15h/tflite_flutter_helper: TensorFlow Lite Flutter Helper Library

    tflite_flutter_helper 인데, 마지막으로 업데이트 한지 3년이 넘었다. 역시 유지보수를 안하는 것 같다.


    dependencies:
      flutter:
        sdk: flutter
      cupertino_icons: ^1.0.6
      image: ^3.3.0
      image_picker: ^1.1.2
      tflite_flutter: ^0.9.1
      tflite_flutter_helper: ^0.3.1
      camera: ^0.9.8+1

    다만 tflite_flutter_helper 는 유지보수 안한지 오래되어, 현재 Flutter SDK와 Gradle 버전에 호환성 문제가 있어서 소스 코드를 수정해서 써야 했다.


    안드로이드에서 구동해보면 우선 The Android Gradle plugin supports only Kotlin Gradle plugin version 어쩌구 하는 오류가 날 것이다. 이 부분을 수정하고 나면, 또 다른 오류가 발생한다. 그 부분을 해결 해야한다. 그리고 또 실제로 적용시켜보려면 벼래별 오류들을 만나게 될 것이다. 추후에 자세히 다뤄보도록 하겠다.


    아무래도 이런 오픈소스 유지보수를 제대로 안할 것 같고 호환성 문제를 임시적으로 해결했더라도 추후에 다시 발생 할 수가 있기 때문에…. 장기적으로는 네이티브 코드로 구현할 줄 알아야 할것 같다.


    senspond

    안녕하세요. Red, Green, Blue 가 만나 새로운 세상을 만들어 나가겠다는 이상을 가진 개발자의 개인공간입니다.

    댓글 ( 0 )

    카테고리내 관련 게시글

    현재글에서 작성자가 발행한 같은 카테고리내 이전, 다음 글들을 보여줍니다

    @senspond

    >

    개발>모바일

    • 안드로이드(Android) OpenCV 세팅하기, Junit 테스트를 위한 Java 세팅까지

      안드로이드(Android) OpenCV 세팅하기, Junit 테스트를 위한 Java 세팅까지
        2025-01-22 (수) 05:59
      1. [현재글] 딥러닝 모델을 플러터로 모바일기기에 임베딩하기

        딥러닝 모델을 플러터로 모바일기기에 임베딩하기
          2024-11-14 (목) 12:12
        1. Flutter에서 수학 수식 표기하는 방법, 어떤 오픈소스 라이브러리를 써야 할까?

          Flutter에서 수학 수식 표기하는 방법, 어떤 오픈소스 라이브러리를 써야 할까? 에 대해서 사용해 본 오픈소스 라이브러리들을 비교 분석해서 정리해본 글입니다.
            2024-08-28 (수) 03:50
          1. Android Native C++ 로그를 안드로이드 logcat로 확인하기

            Android Native C++ 로그를 안드로이드 logcat로 확인하기
              2025-02-08 (토) 06:07
            1. 구글 플레이스토어 인앱결제 내부테스트 하는방법

              구글 플레이스토어 인앱결제 내부테스트 하는방법을 정리해봤습니다.
                2024-09-03 (화) 07:57