상세 컨텐츠

본문 제목

Detectron2 활용법 (3) - Model setup

Machine Learning

by 땅콩또복 2022. 10. 14. 21:54

본문

여기서 다뤄볼 핵심은 크게 4가지다

  1. Use models
  2. Write models
  3. Training
  4. Evaluation

 

1. Model architecture 만들기

Model architecture 설계를 위해 크게 3가지를 사용한다.

  • build_model: Backbone과 Head가 함께 있는 Architecture
  • build_backbone: Backbone만 만들기
  • build_roi_heads: Head만 만들기
from detectron2.modeling import build_model
model = build_model(cfg)  # returns a torch.nn.Module

 

2. Checkpoint load & save 하기

 

앞에서 만들어진 model architecture에 pretrained weight를 입혀주는 단계 & 학습이 끝난 이후 save 하는 단계이다.

Load 과정에서 사용할 수 있는 Format은 pth와 pkl 파일 이다. 각각,

  • .pth: torch.{load, save}
  • .pkl: pickle.{dump,load}
from detectron2.checkpoint import DetectionCheckpointer

# Checkpoint load 하는 단계, usually from cfg.MODEL.WEIGHTS
DetectionCheckpointer(model).load(file_path_or_url)  

# Checkpoint save 하는 단계
checkpointer = DetectionCheckpointer(model, save_dir="output")
checkpointer.save("model_999")  # save to output/model_999.pth

 

3. Model 사용 하기

3-1. Training 단계에서 Model 사용법

training 단계에서는 항상 EventStorage안에서 forward processing이 이뤄져야 한다.

from detectron2.utils.events import EventStorage
with EventStorage() as storage:
  losses = model(inputs)

3-2. Inference 단계에서 Model 사용법

두가지 방법이 있다

  1. DefaultPredictor를 사용 (simple inference 방법)
  2. pytorch 처럼 직접 inference 수행 (하단 확인)
model.eval()
with torch.no_grad():
  outputs = model(inputs)

 

4. (Standard) Model Input Format

직접 해보자. 하나하나 설명하기에는 너무 많다. 핵심만 보자.

기본적으로 input은 list[dict] 형태로 되어 있고 각각의 dict하나의 image에 대한 정보를 담고 있다.

Dict는 다음과 같은 keys를 포함하고 있다.

Key value Type Description
"image" tensor image (C, H, W)
"height", "width"   image size in inference (즉, 항상 training에 사용한 사이즈와 같으 필요가 없음)
"instances"   instance object for training (including gt_boxes, gt_classes, gt_masks, gt_keypoints)
"sem_seg" tensor[int] semantic segmentation ground truth in (H, W) format. 값은 label (from 0)
"proposals"   instance object only for Fas R-CNN style models

5. (Standard) Model Output Format

직접 해보자. 하나하나 설명하기에는 너무 많다. 핵심만 보자.

Dict는 다음과 같은 keys를 포함하고 있다.

Key value
Type Description
"instances"   instance object (including “pred_boxes”, "scores", "pred_classes", “pred_masks”,
“pred_keypoints”)
"sem_seg" tensor semantic segmentation prediction (num_categories, H, W)
"proposals"   instance object (including “proposal_boxes” and “objectness_logits”)
“panoptic_seg” pred   segment id of each pixel
segments_info   segment information (including "id", "isthing", and "category_id")

 

관련글 더보기

댓글 영역