Meta Segment Anything
本文介紹 Meta FAIR 新出的套件 segment-anything
的實作方法 & 結果
Demo 網站
Description
Goal
- 自動化 segment 任務 (zero shot), 並可以根據 prompt (point, box, text) 進行調整
Data
- SA-1B: 高達 11M 張圖片, 1.1B 個 mask 結果 (由 SAM 生成)
Result
僅實驗自動化切割功能, 文字 prompt 功能未釋出
插畫自動切割
- 左圖:原圖
- 中圖:部分星星沒有被切割出來是因為有設定 threshold 來避免切出太小的物件, 共 109 個區域
- 右圖:挑選最大面積的 6 個物件顯示
Wafer Map 自動切割
- 左圖:原圖
- 中圖:部分文字沒有被切割出來是因為有設定 threshold 來避免切出太小的物件, 共 147 個區域
- 右圖:挑選最大面積的 4 個物件顯示
圖片來源: Development of High Power Green Light Emitting Diode Chips paper
SEM Image 自動切割
- 左圖:原圖
- 中圖:部分文字沒有被切割出來是因為有設定 threshold 來避免切出太小的物件, 共 32 個區域
- 右圖:挑選除了 Top1 以外的物件
圖片來源: DLADC: Deep Learning Based Semiconductor Wafer Surface Defects Recognition paper
Model Architecture
圖片來源: segment-anything paper
Practice
Step 1: 模型下載
SAM model 下載 download, 放入 model
資料夾中
1 2 3
| mkdir model cd model wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
Step 2: 讀入權重
import 相關套件 & load model weight
1 2 3 4 5 6 7 8
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
sam_checkpoint = "./model/sam_vit_h_4b8939.pth" model_type = "vit_h"
device = "cuda" sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device)
|
Step 3: 載入圖片
用 cv2 載入圖片並轉為 array 形式
1 2 3
| import cv2 image = cv2.imread('./data/test.png') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
Step 4: 設定模型參數 & 預測
套用 mask 生成器並設定相關參數, 實作後認為 points_per_side
, pred_iou_thresh
較為重要。
points_per_side
是控制採樣點的個數, 直接影響到輸出 mask 的質量
pred_iou_thresh
是輸出 mask 機率的 threshold
輸出結果包含了每個 mask 結果的面積大小, bounding box, mask 座標等等。
1 2 3 4 5 6 7 8 9 10
| mask_generator = SamAutomaticMaskGenerator( model=sam, points_per_side=32, pred_iou_thresh=0.9, stability_score_thresh=0.92, crop_n_layers=1, crop_n_points_downscale_factor=2, min_mask_region_area=100, ) masks2 = mask_generator.generate(image)
|
Step 5: 顯示分割結果
依照 mask 面積大小排序對原圖進行 mask 著色
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| import numpy as np import matplotlib.pyplot as plt
def show_anns(anns): if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) ax = plt.gca() ax.set_autoscale_on(False) for ann in sorted_anns: m = ann['segmentation'] img = np.ones((m.shape[0], m.shape[1], 3)) color_mask = np.random.random((1, 3)).tolist()[0] for i in range(3): img[:,:,i] = color_mask[i] ax.imshow(np.dstack((img, m*0.35)))
plt.figure(figsize=(15,15)) plt.imshow(image) show_anns(masks2) plt.axis('off') plt.show()
|
Step 6: 去背結果
建立一個 mask matrix 在乘上原本圖片 matrix 後, 即可得到去背的圖片
1 2 3 4 5 6 7 8 9 10 11 12
| final_mask = np.zeros(image.shape[:2],dtype=bool) for i in range(len(masks2)): final_mask +=masks2[i]['segmentation']
mask_image = image.copy() for i in range(3): mask_image[:,:,i] *=final_mask
plt.figure(figsize=(15,15)) plt.imshow(mask_image) plt.axis('off') plt.show()
|