Load Pretrained model in pytorch
Pretrained model
- pth๋ก ์ ์ฅ๋ torch pretrained model(weight)๋ฅผ ๋ถ๋ฌ์์ ์ฌ์ฉ
- weight์ ์ผ๋ถ๋ง ๋ถ๋ฌ์์ ์ฌ์ฉํ ์ ์๋ค.
- pth = dictionary ๋ก ๊ตฌ์ฑ๋๋ค.
Get format
pth ํ์ผ์ Dictionary ํํ๋ก ์ ์ฅ๋์ด ์๋ค. pytorch์ load๋ฅผ ํตํด์ ๋ถ๋ฌ์ฌ ์ ์๋ค.
import torch
model = torch.load('model.pth')
print(model.keys())
model.keys()
๋ฅผ ์ฌ์ฉํด์ key ๊ฐ๋ค์ ๋ถ๋ฌ์ฌ ์ ์๋๋ฐ, ์ด๊ฒ์ผ๋ก ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ํ์
ํ ์ ์๋ค. ํ์ฌ ์์ ๋ก ์ฌ์ฉํ๊ณ ์๋ pth ํ์ผ์ mobilenet-ssd-v1 ๋ชจ๋ธ์ mAP 0.675 pretrained weight ํ์ผ์ด๋ค.
์ด ๋ชจ๋ธ์ ๊ฒฝ์ฐ์๋ object detection ๋ชจ๋ธ์ด๊ธฐ ๋๋ฌธ์ regression๊ณผ classfication์ด ๋ชจ๋ ์กด์ฌํ๋ค. ์ฌ๊ธฐ์
- base_net์ weight, bias, running mean, running var
- extras์ weight, bias
- classification_header์ weight, bias
- regression_header์ weight, bias
๊ฐ ๋ ์ด์ด๋ณ๋ก ์ ์ฅ๋์ด ์๋ค. (๋ชจ๋ธ๋ง๋ค ํฌ๋งท์ด ๋ค๋ฅผ ์ ์๋ค.)
์ด ๋ชจ๋ธ์์ ๋ถ๋ฅํ๋ class์ ๊ฐ์๋ฅผ ์ป๊ธฐ ์ํด์๋ ๋ง์ง๋ง output layer์ ์์ดํ ๊ฐ์๋ฅผ ํ์ธํ๋ฉด ๋๋๋ฐ,
output_len = len(list(model.values())[-1])
์ด๋ ๊ฒ ํ์ธํ ์ ์๋ค. ๋ค๋ฅธ ๋ฐ์ดํฐ ํฌ๋งท์์๋ ์์ ๋์ ๋๋ฆฌ ์๋์ value๊ฐ ์ ์ฅ๋์ด ์์ ์๋ ์๋ค. ๊ทธ๋ฐ ๊ฒฝ์ฐ์๋
output_len = len(list(model['{your-key}'].values())[-1])
์์ ๊ฐ์ด ์ ๊ทผํ ์ ์๋ค. ์์ ๋ชจ๋ธ์ ๊ฒฝ์ฐ classification ๊ฐ์๊ฐ 24์ด๋ค.
Extract backbone
๋ด ๊ฒฝ์ฐ์๋ pretrained model๊ณผ ๋ค๋ฅธ ์ปค์คํ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ ์์ ์ด๋ผ class์ ๊ฐ์๊ฐ ๋ฌ๋๋ค. ํ์ง๋ง classification๊ณผ ๊ด๊ณ๊ฐ ์ ์, bounding box๋ฅผ ํ์ตํ ๋ฐ์ดํฐ์ธ base net ํํธ๋ ๊ฐ์ ธ๊ฐ๋๊ฒ ํ์ต ์๊ฐ์ ์ค์ผ ์ ์๋ ๋ฐฉ๋ฒ์ด๋ผ๊ณ ์๊ฐํด์ pretrained model์ ์ผ๋ถ weight๋ง ์๋ผ์ ์ ์ฉํ๊ธฐ๋ก ํ๋ค.
Extract weights of specific module from pretrained model file
์ ๋งํฌ์ pytorch forum์ ์ฐธ๊ณ ํ๋ค.
dictionary key์ ์ด๋ฆ์ ๊ธฐ์ค์ผ๋ก ์ dictionary์ ์ถ๊ฐํ๋ค.
new_weights = {}
for key, value in model.items():
if key.startwith('base'): # key์ ์์ ์ด๋ฆ์ด base์ธ ๊ฒฝ์ฐ
new_weight[key] = value # ์ dictionary์ ์ถ๊ฐ
์๋ก ๋ง๋ dictionary๋ฅผ pth ํ์ผ๋ก ์ ์ฅํ๋ค.
torch.save(new_weights, {save_path})
Load model and apply
model load๋ torch.load_state_dict
๋ก ํ ์ ์๋ค.
torch.load_state_dict
๋ argument๋ก path๊ฐ ์๋๋ผ dictionary object๋ฅผ ์๊ตฌํ๊ธฐ ๋๋ฌธ์, torch.load
๋ก ๋จผ์ ๋ถ๋ฌ์จ dictionary๋ฅผ ๋ฃ์ด์ฃผ์ด์ผํ๋ค.
model = torch.load_state_dict(torch.load('{model_path}'))
๊ทธ๋ฐ๋ฐ ์ง๊ธ์ base_net๋ง backbone์ผ๋ก ์ ์ฅ๋์ด ์๋ค๋ฉด ๋ชจ๋ธ์ ๋ท๋ถ๋ถ์ธ extra๋ classification, regression ๋ถ๋ถ์ ์ด๋ป๊ฒ weight๋ฅผ ๊ฐ์ ธ๊ฐ๊น? initial value๋ฅผ ์์์ ์ง์ด๋ฃ์ด์ฃผ๋? ํ๋ฉด ์๋๋ค. ์ด๋๋ก ์คํํ ๋ layer ์ด๋ฆ์ด ๋ง์ง ์์ผ๋ฉด ๋น์ฐํ key๋ก ๋ง์ถฐ์ฃผ์ด์ผํ๊ณ , ๋ชจ๋ธ์ ํฌ๊ธฐ๋ ์ด์ธ์ ๊ฒ๋ค์ด ๋ง์ง ์์ผ๋ฉด ๋ง์ถฐ์ฃผ์ด์ผํ๋ค.
# ๋ชจ๋ธ ๋ ์ด์ด๊ฐ ์ ๋ง๋ ๊ฒฝ์ฐ ๋ฐ์
RuntimeError: Error(s) in loading state_dict for ....
๋ชจ๋ธ์ ๋ ์ด์ด๊ฐ ๋์น๊ฑฐ๋ ๋ถ์กฑํ ๋ ํน์ key๊ฐ ๋ค๋ฅผ ๋ ๋๋จธ์ง๋ฅผ ๋ฌด์ํด๋ฒ๋ฆฌ๋ option์ ์ถ๊ฐํ ์ ์๋ค. ๋ค๋ง key ์ด๋ฆ์ด ๋ค๋ฅด๋ฉด ์ฃ๋ค ๋ฌด์ํด๋ฒ๋ฆฌ๊ธฐ ๋๋ฌธ์ ํ์ํ ๋ถ๋ถ๊ณผ ํ์ฌ ์ปค์คํ ๋ชจ๋ธ์ key๊ฐ ๊ฐ์์ง ๋ค๋ฅธ์ง ํ์ธํ๊ณ ๋ค๋ฅด๋ค๋ฉด ์์ ํด์ค ๋ค์ ํด์ผํ๋ค.
model = CustomModel.load_state_dict(torch.load('{model_path}'), strict=False)
์ด๋ ๊ฒ ํ๋ฉด ์์ ๋ชจ๋ธ์์ extra, classification, regression ๋ถ๋ถ ๋ ์ด์ด์ weight๋ฅผ ๋ก๋ํ์ง ๋ชปํ๋ ๊ฒ์ ๋ฌด์ํ๊ณ ์งํํด์ค๋ค.