@senspond
>
딥러닝 모델을 플러터로 모바일기기에 임베딩하기
오랜만에 블로그에 글을 쓴다.
MNIST 데이터셋으로 학습한 손글씨 분류 모델을 Flutter로 임베딩 해봤다.
TensorFlow Lite (TFLite)로 변환을 해야 모바일 기기에 임베딩을 할 수가 있다. tflite는 TensorFlow의 경량 버전으로, 모바일 및 임베디드 장치에서 딥러닝 모델을 추론 할 수 있도록 설계된 프레임워크이다. TensorFlow Lite는 경량성, 효율성, 빠른 추론을 목표로 하며, 스마트폰, IoT 장치, 마이크로 컨트롤러와 같은 리소스가 제한된 장치에서도 머신러닝 추론을 수행할 수 있게 한다. 양자화 기법 등을 통해 모델 사이즈를 줄일 수 있고, 오직 추론만 가능하도록 설계되어 있다.
먼저 ONNX 로 변환한다음, 텐서플로우 모델로 변환하고 최종적으로 텐서플로우 라이트(tflite)로 변환해야 한다. 변환과정이 다소 번거롭고 복잡하기에 변환하는 과정에서 네트워크 정보가 손실되는 경우들이 생긴다. 파이토치, onnx, 텐서플로우, 텐서플로우 라이트 모두 고유한 연산자를 사용하는데, 일부 특정 프레임워크에서 사용되는 비표준 연산자 같은 경우 제대로 맵핑이 되지 않을 수가 있다. 딥러닝 연구를 할 때는 파이토치를 많이 사용하지만, tflite로 모바일 기기 임베딩을 할 때는 텐서플로우로 개발을 하는 것이 정신건강에 좋을 것 같다.
import torch
import torch.onnx
# PyTorch 모델 정의 (예시 모델)
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
# 모델 인스턴스 생성 및 저장
model = SimpleModel()
model.eval()
# 더미 입력 정의 (모델의 입력 형식과 일치해야 함)
dummy_input = torch.randn(1, 10)
# 모델을 ONNX 형식으로 저장
onnx_model_path = "simple_model.onnx"
torch.onnx.export(
model, dummy_input, onnx_model_path,
input_names=['input'], output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print(f"ONNX 모델이 저장되었습니다: {onnx_model_path}")
import onnx
from onnx_tf.backend import prepare
# ONNX 모델 로드
onnx_model = onnx.load(onnx_model_path)
# TensorFlow 모델로 변환
tf_rep = prepare(onnx_model)
# 변환된 TensorFlow 모델 저장
tf_model_path = "simple_model_tf"
tf_rep.export_graph(tf_model_path)
print(f"TensorFlow 모델이 저장되었습니다: {tf_model_path}")
import tensorflow as tf
# TensorFlow 모델 로드
saved_model_dir = tf_model_path
# TFLite Converter 사용
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# 모델 최적화 옵션 (선택사항)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# TFLite 모델로 변환
tflite_model = converter.convert()
# TFLite 모델 저장
tflite_model_path = "simple_model.tflite"
with open(tflite_model_path, "wb") as f:
f.write(tflite_model)
print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")
import tensorflow as tf
# TensorFlow 모델 로드
saved_model_dir = tf_model_path
# TFLite Converter 사용
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# 모델 최적화 옵션 (선택사항)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# TFLite 모델로 변환
tflite_model = converter.convert()
# TFLite 모델 저장
tflite_model_path = "simple_model.tflite"
with open(tflite_model_path, "wb") as f:
f.write(tflite_model)
print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")
.h5 혹은 .keras 형태로 저장된 케라스 모델을 텐서플로우 모델로 변환 후 텐서플로우 라이트(tflite) 로 변환해야 한다.
import tensorflow as tf
# Keras 모델 (.h5 /.keras 파일) 로드
h5_model_path = "my_model.h5" # Keras 모델이 저장된 경로
keras_model = tf.keras.models.load_model(h5_model_path)
# TFLiteConverter를 사용하여 Keras 모델을 TFLite 모델로 변환
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
# 모델 최적화 옵션 설정 (선택사항)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# TFLite 모델 변환 수행
tflite_model = converter.convert()
# TFLite 모델을 파일로 저장
tflite_model_path = "my_model.tflite"
with open(tflite_model_path, "wb") as f:
f.write(tflite_model)
print(f"TFLite 모델이 저장되었습니다: {tflite_model_path}")
# 변환된 TFLite 모델의 입출력 확인
interpreter = tf.lite.Interpreter(model_path="style_transfer_model.tflite")
interpreter.allocate_tensors()
# 입력 텐서 정보 출력
input_details = interpreter.get_input_details()
print("Input Tensor Details:")
for detail in input_details:
print(detail)
# 출력 텐서 정보 출력
output_details = interpreter.get_output_details()
print("Output Tensor Details:")
for detail in output_details:
print(detail)
플러터에서 텐서플로우 라이트 모델을 사용하기 위한 패키지로는 tflite 와 tflite_flutter 가 있다.
GitHub - shaqian/flutter_tflite: Flutter plugin for TensorFlow Lite
위 tflite는 마지막으로 업데이트 된지 3년이 넘었고, 현재 유지보수를 안하고 있고 제대로 적용이 되지 않았다.
현재 플러터 SDK (sdk: '>=3.4.4 <4.0.0') 에서 필자가 정상적으로 테스트 성공한 것은 tflite_flutter 와tflite_flutter_helper 이다.
GitHub - tensorflow/flutter-tflite
위 tflite_flutter는 작년까지 업데이트 기록이 있는데, 최근에 유지보수를 안하는 것 같다.
GitHub - am15h/tflite_flutter_helper: TensorFlow Lite Flutter Helper Library
tflite_flutter_helper 인데, 마지막으로 업데이트 한지 3년이 넘었다. 역시 유지보수를 안하는 것 같다.
dependencies:
flutter:
sdk: flutter
cupertino_icons: ^1.0.6
image: ^3.3.0
image_picker: ^1.1.2
tflite_flutter: ^0.9.1
tflite_flutter_helper: ^0.3.1
camera: ^0.9.8+1
다만 tflite_flutter_helper 는 유지보수 안한지 오래되어, 현재 Flutter SDK와 Gradle 버전에 호환성 문제가 있어서 소스 코드를 수정해서 써야 했다.
안드로이드에서 구동해보면 우선 The Android Gradle plugin supports only Kotlin Gradle plugin version 어쩌구 하는 오류가 날 것이다. 이 부분을 수정하고 나면, 또 다른 오류가 발생한다. 그 부분을 해결 해야한다. 그리고 또 실제로 적용시켜보려면 벼래별 오류들을 만나게 될 것이다. 추후에 자세히 다뤄보도록 하겠다.
아무래도 이런 오픈소스 유지보수를 제대로 안할 것 같고 호환성 문제를 임시적으로 해결했더라도 추후에 다시 발생 할 수가 있기 때문에…. 장기적으로는 네이티브 코드로 구현할 줄 알아야 할것 같다.
안녕하세요. Red, Green, Blue 가 만나 새로운 세상을 만들어 나가겠다는 이상을 가진 개발자의 개인공간입니다.
현재글에서 작성자가 발행한 같은 카테고리내 이전, 다음 글들을 보여줍니다
@senspond
>