본문 바로가기

ML & AI Theory

tf.keras 모델의 저장과 복원

반응형

tf.keras API로 만든 모델을 저장하고 복원하는 것은 아주 쉽습니다.  먼저 앞서 만든 모델 가중치를 save_weights()메서드로 저장해 보겠습니다.

 

model.save_weights('./model/simple_weights.h5')

이 코드를 실행하면 h45 파일을 생성하고 모든 층의 가중치를 저장합니다. save_weights메서드는 기본적으로 텐서플로의 체크포인트(checkpoint)포맷으로 가중치를 저장합니다. save_format 매개변수를 'h5'로 지정하여 HDF5파일 포맷으로 저장할 수 있습니다. 이 메서드는 똑똑학도 파일 이름의 확장자가 .h5이면 자동으로 HDF5포맷으로 저장합니다. 

 

저장된 가중치를 사용하려면 새로운 모델을 만들고 load_weight()메서드를 사용하여 가중치를 로드합니다.

아래 코드는 모델을 훈련시키지 않고 이전에 학습한 가중치를 사용합니다. 가중치가 올바르게 로드되었는지 확인하기 위해 테스트 세트에 대한 손실 점수를 계산해 보겠습니다.

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1, input_dim=1))
model.compile(optimizer='sgd', loss='mse')
model.load_weights('./model/simple_weights.h5')

 

테스트 세트를 평가하려면 evaluate() 메서드를 사용합니다.

model.evaluate(x_test, y_test)

 

저장되었던 가중치가 잘 적용된 것 같습니다. 가중치 외에 모델 전체를 저장을 할 경우에 tf.keras 모델은 save() 메서드를 사용하여 가중치와 네트워크 구졲지 HDF5포맷으로 저장할 수 있습니다.

model.save('./data/simple_model.h5')

저장한 모델을 로드하려면 load_model()함수를 사용합니다. 저장된 모델 구조를 그대로 사용하기 때문에 층을 다시 추가할 필요가 없습니다. 다음 코드에서 모델을 로드하고 바로 테스트 세트를 평가합니다.

model = tf.keras.models.load_model('./data/simple_model.h5')
model.evaluate(x_test, y_test)

 

모델을 훈련하는 동안 ModelCheckpoint 콜백을 사용하여 최고의 성능을 내는 가중치를 저장 할 수 있습니다. ModelCheckpoint 콜백은 더 이상 성능이 개선되지 않을 떄 훈련을 멈추게 하는 EarlyStopping콜백과 함께 사용하는 경우가 많습니다. EarlyStopping 콜백 클래스는 기본적으로 점증 손실을 모니터링합니다. patience매개변수가 지정한 에포크 횟수 동안 모니터링 지표가 개선되지 않으면 훈련을 중지합니다.

 

다음 코드는 ModelCheckpoint 콜백을 사용하는 예시입니다. 검증 손실을 모니터링하면서 (mornitor='val_loss')최상의 모델 가중치를 저장합니다. (save_best_only=True)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(units=1, input_dim=1))
model.compile(optimizer='sgd', loss='mse')
callback_list = [tf.keras.callbacks.ModelCheckpoint(filepath='./data/my_model.h5', monitor='val_loss', save_nest_only=True), tf.keras.callbacks.EarlyStopping(patience=5)]
history = model.fit(x_train, y_train, epochs=500, validation_split=0.2, callbacks=callback_list)

 

손실 그래프를 그려 보면 검증 손실이 감소되지 않으므로 훈련이 일찍 멈춘 것을 확인할 수 있습니다.

epochs = np.arange(1, len(history.history['loss'])+1)
plt.plot(epochs, history.history['loss'], label='Training loss')
plt.plot(epochs, history.history['val_loss'], label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('loss')
plt.legend()
plt.show()

 

저장된 모델을 로드한 후 ModelCheckpoint 콜백에서 저장한 가중치를 load_weights()메서드로 적재합니다. 이렇게 하면 모델 구조와 함께 최상의 모델 가중치를 유지할 수 있습니다.

model = tf.keras.models.load_model('./data/simple_model.h5')
model.load_weights('./data/my_model.h5')
model.evaluate(x_test, y_test)

 

반응형