[PyTorch] pretrained model load/save, pretrained model editing

2022. 9. 19. 15:32 · 🐬 ML & Data / ❔ Q & etc.

Loading a pretrained model in PyTorch

Pretrained model

  • Load and use a torch pretrained model (its weights) saved as a .pth file.
  • You can also load and use only part of the weights.
  • A .pth file typically holds a dictionary (a state_dict mapping layer names to tensors).

Get format

A .pth file is stored as a dictionary and can be loaded with PyTorch's torch.load.

import torch

model = torch.load('model.pth')
print(model.keys())

model.keys() returns the keys, which lets you work out the model's structure. The .pth file used as the example here is the pretrained weight file of a mobilenet-ssd-v1 model with an mAP of 0.675.
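A minimal sketch of the same inspection that also prints each tensor's shape, which makes the layer layout easier to read. Since the real mobilenet-ssd-v1 file isn't available here, it assumes a small hypothetical stand-in network that is first saved as model.pth in the system temp directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the real checkpoint: a tiny conv + batch-norm
# network whose state_dict is saved as model.pth in the temp directory.
path = os.path.join(tempfile.gettempdir(), 'model.pth')
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
torch.save(net.state_dict(), path)

model = torch.load(path)  # an OrderedDict mapping layer names to tensors
for key, tensor in model.items():
    # The shape next to each key shows how the layers are laid out
    print(key, tuple(tensor.shape))
```

The batch-norm layer contributes running_mean, running_var and num_batches_tracked entries in addition to weight and bias, while the ReLU has no entries at all because it holds no state.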

Because this is an object detection model, it contains both a regression part and a classification part. The file stores the following, layer by layer:

  • base_net: weight, bias, running mean, running var
  • extras: weight, bias
  • classification_header: weight, bias
  • regression_header: weight, bias

(The format can differ from model to model.)
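To see these groups directly, one option is to split each key at its first '.' and tally the entries per top-level module. This is an illustrative sketch: TinySSD is a hypothetical miniature that only borrows the top-level names listed above, and group_by_module is a helper written for this example, not part of PyTorch.

```python
from collections import OrderedDict

import torch.nn as nn

class TinySSD(nn.Module):
    """Hypothetical miniature that reuses the SSD example's top-level names."""
    def __init__(self, num_classes=3):
        super().__init__()
        self.base_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.extras = nn.Sequential(nn.Conv2d(8, 8, 3))
        self.classification_headers = nn.ModuleList([nn.Conv2d(8, 2 * num_classes, 3)])
        self.regression_headers = nn.ModuleList([nn.Conv2d(8, 2 * 4, 3)])

def group_by_module(state_dict):
    """Count tensors and total values per top-level prefix (text before the first '.')."""
    groups = OrderedDict()
    for key, tensor in state_dict.items():
        prefix = key.split('.', 1)[0]
        count, numel = groups.get(prefix, (0, 0))
        groups[prefix] = (count + 1, numel + tensor.numel())
    return groups

groups = group_by_module(TinySSD().state_dict())
for name, (count, numel) in groups.items():
    print(f'{name}: {count} tensors, {numel} values')
```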

To find how many classes this model outputs, check the number of items in the last output layer:

output_len = len(list(model.values())[-1])

In other formats, the values may sit under a key of an outer dictionary. In that case

output_len = len(list(model['{your-key}'].values())[-1])

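you can reach them as above. For the example model, the value obtained this way is 24. Note that len() on a tensor returns only the size of its first dimension, so check which key the last entry belongs to before reading it as a class count: if the last entry is a box-regression layer, the number reflects box outputs rather than classes.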
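A small sketch that returns the key together with the size, so you can confirm which layer the number actually comes from. The nested_key name ('state_dict' below) and the tiny network are hypothetical; check your own file's top-level keys for the real name.

```python
import torch.nn as nn

def last_output_size(checkpoint, nested_key=None, prefix=None):
    """Return (key, size of dim 0) for the last tensor in a state dict.

    nested_key: top-level entry that holds the weights, if the file is wrapped.
    prefix: optionally restrict the search to keys starting with this string.
    """
    state_dict = checkpoint[nested_key] if nested_key else checkpoint
    keys = [k for k in state_dict if prefix is None or k.startswith(prefix)]
    last_key = keys[-1]
    return last_key, state_dict[last_key].shape[0]

# Hypothetical network: the final Linear maps 16 features to 10 outputs
net = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 10))
wrapped = {'state_dict': net.state_dict()}  # weights nested under one key
print(last_output_size(wrapped, nested_key='state_dict'))  # ('2.bias', 10)
print(last_output_size(net.state_dict(), prefix='0.'))     # ('0.bias', 16)
```

Passing a prefix such as 'classification_headers' lets you target the classification layer explicitly instead of relying on whatever happens to be stored last.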

Extract backbone

In my case, I planned to train on custom data that differs from the pretrained model's data, so the number of classes was different. However, I figured that reusing the base net part, which was trained on the bounding-box data and is less tied to classification, would reduce training time, so I decided to cut out only part of the pretrained weights and apply them.

Extract weights of specific module from pretrained model file

I referred to the PyTorch forum thread linked above.

Entries are copied into a new dictionary based on the name of each dictionary key.

new_weights = {}
for key, value in model.items():
    if key.startswith('base'):  # keys whose name starts with 'base' (base_net.*)
        new_weights[key] = value  # add the entry to the new dictionary

Save the new dictionary as a .pth file.

torch.save(new_weights, '{save_path}')
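The filter-and-save steps can also be written as a dict comprehension, followed by a reload to confirm that only backbone entries were written. A sketch under stated assumptions: TinySSD is a hypothetical miniature, backbone.pth in the system temp directory is an assumed path, and the prefix 'base_net.' includes the dot so that a layer named, say, base_net2 would not match by accident.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinySSD(nn.Module):
    """Hypothetical miniature of the SSD layout (backbone, extras, head)."""
    def __init__(self):
        super().__init__()
        self.base_net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.extras = nn.Sequential(nn.Conv2d(8, 8, 3))
        self.classification_headers = nn.ModuleList([nn.Conv2d(8, 6, 3)])

full = TinySSD().state_dict()
# Keep only backbone entries; key names stay unchanged so they still match
# the parameter names of a model that also calls its backbone base_net.
backbone = {k: v for k, v in full.items() if k.startswith('base_net.')}

save_path = os.path.join(tempfile.gettempdir(), 'backbone.pth')  # assumed path
torch.save(backbone, save_path)
reloaded = torch.load(save_path)
print(list(reloaded))
```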

Load model and apply

Weights are loaded into a model with load_state_dict, a method called on the model instance (nn.Module), not a function in the torch namespace.

load_state_dict takes a dictionary object rather than a file path as its argument, so you first load the dictionary with torch.load and pass that in.

model = CustomModel()  # an instance of your model class
model.load_state_dict(torch.load('{model_path}'))
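A small sketch of the call pattern with a hypothetical one-layer model, showing that load_state_dict copies the tensors into the existing instance and returns a named tuple of missing and unexpected keys rather than a model.

```python
import torch
import torch.nn as nn

# Hypothetical model and checkpoint with identical architectures
model = nn.Sequential(nn.Linear(4, 2))
checkpoint = nn.Sequential(nn.Linear(4, 2)).state_dict()

# Copies the tensors into model in place; the return value only reports
# which keys did not line up.
result = model.load_state_dict(checkpoint)
print(result.missing_keys, result.unexpected_keys)  # [] []
print(torch.equal(model[0].weight, checkpoint['0.weight']))  # True
```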

But if only base_net is saved as the backbone, where do the later parts of the model (extras, classification, regression) get their weights from? Does PyTorch quietly fill them with initial values? By default, no: load_state_dict raises an error instead. If the layer names don't match, you have to align the keys, and if the layer sizes or anything else differ, you have to make those match as well.
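When only the names differ, one way to align them is to rewrite the key prefix before loading. A sketch under assumptions: CustomModel here is a hypothetical model that calls its backbone 'backbone' instead of 'base_net', and rename_prefix is a helper written for the example, not a PyTorch API.

```python
import torch.nn as nn

class CustomModel(nn.Module):
    """Hypothetical model that names its backbone 'backbone'."""
    def __init__(self, num_classes=5):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.classifier = nn.Conv2d(8, num_classes, 1)

def rename_prefix(state_dict, old, new):
    """Copy state_dict, replacing a leading `old` in each key with `new`."""
    return {(new + k[len(old):] if k.startswith(old) else k): v
            for k, v in state_dict.items()}

# Pretend backbone checkpoint saved under the SSD-style 'base_net.' prefix
pretrained = {'base_net.' + k: v
              for k, v in nn.Sequential(nn.Conv2d(3, 8, 3)).state_dict().items()}
renamed = rename_prefix(pretrained, 'base_net.', 'backbone.')

model = CustomModel()
unknown = set(renamed) - set(model.state_dict())
print(sorted(renamed))  # ['backbone.0.bias', 'backbone.0.weight']
print(unknown)          # set(), so every renamed key exists in the model
```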

# Raised when the model's layers don't match the checkpoint
RuntimeError: Error(s) in loading state_dict for ....
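To reproduce this error, a sketch that loads a backbone-only dictionary into a hypothetical CustomModel with the default strict=True and catches the exception:

```python
import torch.nn as nn

class CustomModel(nn.Module):
    """Hypothetical model: SSD-style backbone plus a new classification head."""
    def __init__(self):
        super().__init__()
        self.base_net = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.classification_headers = nn.Conv2d(8, 4, 1)

# Backbone-only dictionary, like the one produced by the extraction step
backbone_only = {k: v for k, v in CustomModel().state_dict().items()
                 if k.startswith('base_net.')}

message = ''
try:
    CustomModel().load_state_dict(backbone_only)  # strict=True by default
except RuntimeError as err:
    message = str(err)
print(message.splitlines()[0])  # Error(s) in loading state_dict for CustomModel:
```

The rest of the message lists the head's weight and bias under "Missing key(s) in state_dict".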

You can pass an option that ignores the remaining entries when the model has extra or missing layers, or when keys differ. Be careful, though: every entry whose key name differs is skipped without complaint, so first check whether the keys of the part you need match the keys of your current custom model, and rename them if they don't. This option only relaxes key matching; a tensor whose shape differs from the model's still raises a size-mismatch error.
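Because strict=False does not skip shape mismatches, one approach is to keep only the entries whose key exists in the target model with the same shape before loading. A sketch, assuming a hypothetical CustomModel whose classification head is sized by num_classes; compatible_subset is a helper written for this example.

```python
import torch.nn as nn

class CustomModel(nn.Module):
    """Hypothetical model: fixed backbone, head sized by num_classes."""
    def __init__(self, num_classes=5):
        super().__init__()
        self.base_net = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.classification_headers = nn.Conv2d(8, num_classes, 1)

def compatible_subset(checkpoint, model):
    """Split checkpoint into entries the model can accept and keys it cannot."""
    target = model.state_dict()
    kept, skipped = {}, []
    for key, tensor in checkpoint.items():
        if key in target and target[key].shape == tensor.shape:
            kept[key] = tensor
        else:
            skipped.append(key)
    return kept, skipped

pretrained = CustomModel(num_classes=10).state_dict()  # pretend pretrained weights
model = CustomModel(num_classes=5)
kept, skipped = compatible_subset(pretrained, model)
print(skipped)  # ['classification_headers.weight', 'classification_headers.bias']
```

The kept dictionary can then be passed to load_state_dict with strict=False, since the head entries it no longer contains would otherwise be reported as missing.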

model = CustomModel()
model.load_state_dict(torch.load('{model_path}'), strict=False)

With this, loading proceeds even though the weights of the extras, classification, and regression layers of the example model are not found; those layers simply keep the values they were initialized with when the model was constructed.
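To confirm which layers were skipped, the value returned by load_state_dict can be inspected. A sketch with a hypothetical CustomModel that mirrors the four parts of the example model: missing_keys should list the extras, classification and regression entries, while base_net holds the copied weights.

```python
import torch
import torch.nn as nn

class CustomModel(nn.Module):
    """Hypothetical miniature mirroring the example's four parts."""
    def __init__(self):
        super().__init__()
        self.base_net = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.extras = nn.Sequential(nn.Conv2d(8, 8, 3))
        self.classification_headers = nn.Conv2d(8, 4, 1)
        self.regression_headers = nn.Conv2d(8, 8, 1)

source = CustomModel()  # stands in for the pretrained model
backbone = {k: v for k, v in source.state_dict().items() if k.startswith('base_net.')}

model = CustomModel()
result = model.load_state_dict(backbone, strict=False)
# Keys listed here were not in the dictionary; those layers keep the values
# set by CustomModel's constructor (PyTorch's default layer initialization).
print(result.missing_keys)
print(torch.equal(model.base_net[0].weight, source.base_net[0].weight))  # True
```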

License: Attribution-NonCommercial-NoDerivatives

darly213