딥러닝은 모델을 학습시기며 학습된 모델을 이용하여 결과를 예측하거나 결과물을 생성해냅니다.
이러한 모델들은 학습이 완료된 뒤(혹은 학습중) 저장하여 사용할 수 있습니다.
모델을 저장하는 방법에는 다음 3가지 방법이 있습니다.
- ModelCheckpoint
- model.save()
- to_json(), to_yaml() and save_weight
한가지씩 살펴보도록 하겠습니다.
1.ModelCheckpoint
ModelCheckpoint는 콜백함수로 모델을 피팅할때 설정해 줍니다.
예시코드)
from keras import backend as K
from keras import layers as L
from keras.models import Model
from keras.regularizers import l2
from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
def Make_model(train,val):
model_ckpt = ModelCheckpoint('model_ckpt.h5',save_best_only=True)
reduce_lr = ReduceLROnPlateau(patience=8,verbose=1)
early_stop = EarlyStopping(patience=10,verbose=2,monitor='loss')
entry = L.Input(shape=(12, 12, 3))
x = L.SeparableConv2D(256, (3,3), activation='relu')(entry)
x = L.MaxPooling2D((2, 2))(x)
x = L.SeparableConv2D(512, (3, 3), activation='relu', padding='same')(x)
x = L.MaxPooling2D((2, 2))(x)
x = L.SeparableConv2D(1024, (2, 2), activation='relu', padding='same')(x)
x = L.GlobalMaxPooling2D()(x)
x = L.Dense(256)(x)
x = L.ReLU()(x)
x = L.Dense(64, kernel_regularizer=l2(2e-4))(x)
x = L.ReLU()(x)
x = L.Dense(27, activation='softmax')(x)
model = Model(entry,x)
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit_generator(train,validation_data=val,epochs=600,
callbacks=[model_ckpt,reduce_lr,early_stop],verbose=2)
keras.callbacks 안에 들어있으며 다음과 같은 Argument를 가집니다.
tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch', options=None, **kwargs
)
- filepath : 파일의 경로를 지정해 줍니다. 저장될 파일 이름을 적으면 .py 파일과 같은 위치에 지정한파일 이름으로 저장됩니다.
- monitor : 모니터링 하는 값을 정합니다.
- verbose : 0 또는 1로 표시될지 안될지 정합니다.
- save_best_only : True시 모니터링하는 값이 최적의 값이 아니면 덮어쓰지 않습니다.
- mode : save_best_only가 트루일때, 모니터링 하는 값의 최대, 최소중 어떤 점을 모니터링할지 정합니다. {max,min,auto}중 정하며 기본값은 auto입니다.
- save_freq : 'epoch' 또는 정수값을 가지며, 'epoch'일 경우 매 epoch마다 저장합니다. 정수일경우 batch수에 맞춰 돌아갑니다.
- option : tf.train.CheckpointOptions에 의해서 가중치(weight)만 저장하게 할 수 있습니다.
2. model.save()
save() 함수는 가장 기본적인 저장 방법입니다. Tensorflow Savemodel 형식이 기본입니다. 괄호안에 경로를 지정하면 경로를 통해 저장합니다.
- 모델의 아키텍처 및 구성
- 훈련 중에 학습된 모델의 가중치 값
- 모델의 컴파일 정보(compile()이 호출된 경우)
- 존재하는 옵티마이저와 그 상태(훈련을 중단한 곳에서 다시 시작할 수 있게 해줌)
다만 여러 경우에서 h5형식을 쓰는 경우가 많습니다. h5형식으로 저장하는 방법은 다음과 같습니다.
- 파일확장자에 .h5 ,.keras를 명시함
- save함수에 format이 h5임을 전해줌.
model.save('save_model') # Tensorflow SaveModel 방식 저장
model.save('save_model.h5') #keras h5
3.to_json(), to_yaml() and save_weight
위의 두 방법이 기본적으론 아키텍처와 가중치를 모두 저장하는 방법이었다면 이번 방법은 아키텍처와 가중치를 따로 저장하는 방법입니다. 다음과 같은 장점이 있습니다.
- 가중치저장을 따로 하므로 학습에 오류가 생길시 복구가 가능하다.
- 아키텍처를 그대로 가져다 불러올수 있어 앙상블기법등 같은 아키텍처를 사용하기에 좋다.
- 가중치에 가중치를 더해서 학습을 이어서 진행 할 수 있다.
model의 아키텍처는 .to_json을 통해 json형식으로 변경되고 저장할 수 있습니다.
model_json = model.to_json()
with open('model_json.json', 'w') as f:
f.write(model_json)
마찬가지로 yaml을 통해 저장 될수 있습니다.
model_yaml=model.to_yaml()
with open('model_yaml.yaml', 'w') as f:
f.write(model_yaml)
모델의 가중치는 save_weight를 통해 h5형식으로 저장됩니다.
model.save_weights('save_weight.h5')
1,2번 방법이 모델을 한번에 불러올 수 있다면 3번 방법은 두가지를 한꺼번에 불러와야 합니다.
한가지만 불러올경우 다음과 같은 오류가 발생합니다.
C:\ProgramData\Anaconda3\envs\Project_f\python.exe D:/ML2/Plant_Disease/py/Load_model.py
Using TensorFlow backend.
2020-09-23 21:54:07.737753: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library cudart64_100.dll
Traceback (most recent call last):
File "D:/ML2/Plant_Disease/py/Load_model.py", line 11, in <module>
load_model('D:\ML2\Plant_Disease\small_weights.h5')
File "D:/ML2/Plant_Disease/py/Load_model.py", line 4, in load_model
model = load_model(file)
File "C:\ProgramData\Anaconda3\envs\Project_f\lib\site-packages\keras\engine\saving.py", line 492, in load_wrapper
return load_function(*args, **kwargs)
File "C:\ProgramData\Anaconda3\envs\Project_f\lib\site-packages\keras\engine\saving.py", line 584, in load_model
model = _deserialize_model(h5dict, custom_objects, compile)
File "C:\ProgramData\Anaconda3\envs\Project_f\lib\site-packages\keras\engine\saving.py", line 270, in _deserialize_model
model_config = h5dict['model_config']
File "C:\ProgramData\Anaconda3\envs\Project_f\lib\site-packages\keras\utils\io_utils.py", line 318, in __getitem__
raise ValueError('Cannot create group in read-only mode.')
ValueError: Cannot create group in read-only mode.
Process finished with exit code 1
읽기전용 모드라는 오류라 처음 본다면 쉽게 해결하기 어려울 것이라 예상됩니다.
각각의 방법을 다적은 예시는 다음과 같습니다.(모델은 점자 CNN 모델입니다.)
from keras import backend as K
from keras import layers as L
from keras.models import Model
from keras.regularizers import l2
from keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
def Make_model(train,val):
K.clear_session()
model_ckpt = ModelCheckpoint('model_ckpt.h5',save_best_only=True)
reduce_lr = ReduceLROnPlateau(patience=8,verbose=1)
early_stop = EarlyStopping(patience=10,verbose=2,monitor='loss')
entry = L.Input(shape=(12, 12, 3))
x = L.SeparableConv2D(256, (3,3), activation='relu')(entry)
x = L.MaxPooling2D((2, 2))(x)
x = L.SeparableConv2D(512, (3, 3), activation='relu', padding='same')(x)
x = L.MaxPooling2D((2, 2))(x)
x = L.SeparableConv2D(1024, (2, 2), activation='relu', padding='same')(x)
x = L.GlobalMaxPooling2D()(x)
x = L.Dense(256)(x)
x = L.ReLU()(x)
x = L.Dense(64, kernel_regularizer=l2(2e-4))(x)
x = L.ReLU()(x)
x = L.Dense(27, activation='softmax')(x)
model = Model(entry,x)
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit_generator(train,validation_data=val,epochs=600,
callbacks=[model_ckpt,reduce_lr,early_stop],verbose=2)
model.save('save_model.h5')
model.save_weights('save_weight.h5')
model_json = model.to_json()
with open('model_json.json', 'w') as f:
f.write(model_json)
model_yaml=model.to_yaml()
with open('model_json.yaml', 'w') as f:
f.write(model_yaml)
return history
'Back > Deep Learning' 카테고리의 다른 글
[Tensorflow] Tensorflowlite를 이용한 Image classification model maker (1) | 2020.10.08 |
---|---|
[Keras] 모델 불러오기 (0) | 2020.09.24 |
[Keras]ImageGenerator 사용하기. (0) | 2020.09.18 |
[Python] [CNN]점자번역 프로그램(7) - 정리 (1) | 2020.09.03 |
[Python] [CNN]점자번역 프로그램(6) (0) | 2020.09.02 |