-
모델 앙상블 기법 - NMS, WBF개발자노트/네이버 부스트캠프 AI 2024. 1. 16. 11:47
딥러닝 결과의 성능을 높이기 위해서 여러 모델의 예측결과를 앙상블 한다.
앙상블에도 여러가지 방법이 있는데 NMS, WBF를 사용하고 정리해본다.
Non-Maximum Suppression (NMS)
겹치는 부분의 정도가 일정 threshold 이상인 박스들 중에서 가장 확률이 높은 박스를 선택하고
나머지 겹치는 박스를 제거하는 기법
장점 : 간단함. 중복 제거를 효과적으로 할 수 있음.
단점 : 가장 확률이 높은 박스를 선택하기 때문에 다양한 객체를 탐지하기 어려움, 임계값에 민감하여 임계값 설정이 중요함.
Weighted Boxes Fusion (WBF)
각 모델의 예측에서 박스의 좌표를 가중 평균하여 최종 예측을 생성함.
각 박스에 대한 예측 점수를 가중치로 사용하여 모델 간의 성능 차이를 고려함.
장점 : 다양한 모델의 장점을 효과적으로 결합할 수 있음.
단점 : 계산 비용 증가. 가중치 선택의 기준이 있어야 성능 향상을 기대할 수 있음.
참고 사항
대회 제출 포맷에 맞게 to_csv의 매개변수 설정을 해준다.
제출전에는 꼭 포맷을 확인한다.
NMS 코드 예시
더보기prediction_strings = [] file_names = [] # ensemble 시 설정할 iou threshold 이 부분을 바꿔가며 대회 metric에 알맞게 적용해봐요! iou_thr = 0.5 # 각 image id 별로 submission file에서 box좌표 추출 for i, image_id in enumerate(image_ids): prediction_string = '' boxes_list = [] scores_list = [] labels_list = [] index_str = image_id.split('/')[1].split('.')[0] image_index = int(index_str) image_info = coco.loadImgs(image_index)[0] # 각 submission file 별로 prediction box좌표 불러오기 for df in submission_df: predict_string = df[df['image_id'] == image_id]['PredictionString'].tolist()[0] predict_list = str(predict_string).split() if len(predict_list)==0 or len(predict_list)==1: continue predict_list = np.reshape(predict_list, (-1, 6)) box_list = [] for box in predict_list[:, 2:6].tolist(): box[0] = float(box[0]) / image_info['width'] box[1] = float(box[1]) / image_info['height'] box[2] = float(box[2]) / image_info['width'] box[3] = float(box[3]) / image_info['height'] box_list.append(box) boxes_list.append(box_list) scores_list.append(list(map(float, predict_list[:, 1].tolist()))) labels_list.append(list(map(int, predict_list[:, 0].tolist()))) # 예측 box가 있다면 이를 ensemble 수행 if len(boxes_list): boxes, scores, labels = nms(boxes_list, scores_list, labels_list, iou_thr=iou_thr) for box, score, label in zip(boxes, scores, labels): prediction_string += str(label) + ' ' + str(score) + ' ' + str(box[0] * image_info['width']) + ' ' + str(box[1] * image_info['height']) + ' ' + str(box[2] * image_info['width']) + ' ' + str(box[3] * image_info['height']) + ' ' prediction_strings.append(prediction_string) file_names.append(image_id)
wbf 코드 예시, (주의 : class label이 int인지 float인지 확인하고 그에 맞게 수정해야 한다. 보통 class는 int)
더보기from ensemble_boxes import weighted_boxes_fusion prediction_strings = [] file_names = [] # ensemble 시 설정할 iou threshold 이 부분을 바꿔가며 대회 metric에 알맞게 적용해봐요! iou_thr = 0.6 # 각 image id 별로 submission file에서 box좌표 추출 for i, image_id in enumerate(image_ids): prediction_string = '' boxes_list = [] scores_list = [] labels_list = [] index_str = image_id.split('/')[1].split('.')[0] image_index = int(index_str) image_info = coco.loadImgs(image_index)[0] # 각 submission file 별로 prediction box좌표 불러오기 for df in submission_df: predict_string = df[df['image_id'] == image_id]['PredictionString'].tolist()[0] predict_list = str(predict_string).split() if len(predict_list) == 0 or len(predict_list) == 1: continue predict_list = np.reshape(predict_list, (-1, 6)) box_list = [] for box in predict_list[:, 2:6].tolist(): box[0] = float(box[0]) / image_info['width'] box[1] = float(box[1]) / image_info['height'] box[2] = float(box[2]) / image_info['width'] box[3] = float(box[3]) / image_info['height'] box_list.append(box) boxes_list.append(box_list) scores_list.append(list(map(float, predict_list[:, 1].tolist()))) labels_list.append(list(map(int, predict_list[:, 0].tolist()))) # 예측 box가 있다면 이를 ensemble 수행 if len(boxes_list): # WBF 수행 boxes, scores, labels = weighted_boxes_fusion(boxes_list, scores_list, labels_list, iou_thr=iou_thr) for box, score, label in zip(boxes, scores, labels): if score < 0.05: continue prediction_string += str(int(label)) + ' ' + str(score) + ' ' + str(box[0] * image_info['width']) + ' ' + str(box[1] * image_info['height']) + ' ' + str(box[2] * image_info['width']) + ' ' + str(box[3] * image_info['height']) + ' ' prediction_strings.append(prediction_string) file_names.append(image_id)
'개발자노트 > 네이버 부스트캠프 AI' 카테고리의 다른 글
[네부캠 AI tech] 11주차 주간회고 (01/15~01/19) (0) 2024.01.19 나의 삶의 지도 '개발자 진화 과정' (0) 2024.01.16 [네부캠 AI tech] 10주차 주간회고 (01/08~01/12) (1) 2024.01.12 (작성중) Grad CAM (0) 2024.01.12 [wandb] 팀 세팅, mmdetection 3에서 팀 초대하는 방법 (0) 2024.01.12