SAM 2 gives you pixel-perfect segmentation masks from a single click. The model takes point or box prompts and returns precise object boundaries in milliseconds on a GPU. Pair that with a WebSocket connection and you get an interactive annotation tool where users click on objects in the browser and instantly see segmentation overlays.
Here is what you need installed:
```shell
pip install 'git+https://github.com/facebookresearch/sam2.git'
pip install fastapi uvicorn websockets pillow numpy
```
You will also need a GPU with at least 8 GB of VRAM for the large model, or 4 GB for the tiny variant.
## Loading SAM 2 for Inference
The SAM2ImagePredictor class handles embedding computation and mask prediction. Load it once at startup and reuse it across requests. The from_pretrained method pulls weights directly from Hugging Face, so you do not need to manage checkpoint files manually.
```python
import torch
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Load the model once -- this downloads weights on first run
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

# Load and set an image (HWC format, uint8, RGB)
image = np.array(Image.open("photo.jpg").convert("RGB"))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)

    # Single point prompt: (x, y) coordinates, label 1 = foreground
    point_coords = np.array([[350, 250]])
    point_labels = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )

# masks shape: (1, H, W) -- boolean array
print(f"Mask shape: {masks.shape}, Score: {scores[0]:.3f}")
```
The set_image call computes the image embedding. This is the expensive step – around 100-200 ms on an RTX 3090. After that, each predict call takes only 5-10 ms because it just runs the lightweight mask decoder.
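Because the embedding dominates the cost, it pays to recompute it only when the image actually changes. A minimal sketch of that idea, using a content hash of the uploaded bytes to detect repeats (the `EmbeddingGate` helper and its names are illustrative, not part of the SAM 2 API):

```python
import hashlib


class EmbeddingGate:
    """Track the last image seen so set_image() runs only on new images.

    Illustrative helper, not part of SAM 2 -- it just compares a
    content hash of the incoming image bytes.
    """

    def __init__(self):
        self._last_digest = None

    def needs_embedding(self, image_bytes: bytes) -> bool:
        digest = hashlib.sha256(image_bytes).hexdigest()
        if digest == self._last_digest:
            return False  # same image: reuse the cached embedding
        self._last_digest = digest
        return True


# Usage sketch: only call the expensive predictor.set_image(...)
# when the gate reports a new image.
gate = EmbeddingGate()
print(gate.needs_embedding(b"image-1"))  # True  -> compute embedding
print(gate.needs_embedding(b"image-1"))  # False -> reuse it
print(gate.needs_embedding(b"image-2"))  # True  -> new image
```

In the WebSocket server below, the same effect falls out naturally because set_image only runs on an explicit set_image message, but a gate like this helps if clients resend the image with every click.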
## Choosing the Right Model Size
SAM 2 ships in four sizes. For real-time annotation, the tiny model gives you sub-50 ms embeddings at a small quality cost:
- facebook/sam2-hiera-tiny – fastest, good for interactive tools
- facebook/sam2-hiera-small – balanced speed and accuracy
- facebook/sam2-hiera-base-plus – higher quality masks
- facebook/sam2-hiera-large – best quality, slower embedding
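One convenient pattern is to make the checkpoint configurable, so you can trade quality for latency without code changes. A small sketch (the SAM2_MODEL environment variable is an assumption of this example, not something the library reads):

```python
import os

# Hugging Face model IDs from the list above, smallest to largest.
SAM2_MODELS = {
    "tiny": "facebook/sam2-hiera-tiny",
    "small": "facebook/sam2-hiera-small",
    "base-plus": "facebook/sam2-hiera-base-plus",
    "large": "facebook/sam2-hiera-large",
}


def pick_model() -> str:
    """Resolve a model ID from the (hypothetical) SAM2_MODEL env var."""
    size = os.environ.get("SAM2_MODEL", "tiny")
    if size not in SAM2_MODELS:
        raise ValueError(
            f"Unknown model size {size!r}; choose from {sorted(SAM2_MODELS)}"
        )
    return SAM2_MODELS[size]


print(pick_model())  # -> facebook/sam2-hiera-tiny unless SAM2_MODEL is set
```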
## Building the WebSocket Server
The server needs to load SAM 2 once at startup, accept image uploads and click coordinates over WebSocket, and stream back segmentation masks. FastAPI’s lifespan context manager is the right pattern for model initialization.
```python
import io
import json
import base64
from contextlib import asynccontextmanager

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from sam2.sam2_image_predictor import SAM2ImagePredictor


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load SAM 2 model
    app.state.predictor = SAM2ImagePredictor.from_pretrained(
        "facebook/sam2-hiera-large"
    )
    app.state.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"SAM 2 loaded on {app.state.device}")
    yield
    # Shutdown: free GPU memory
    del app.state.predictor
    torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)


@app.websocket("/ws/segment")
async def segment_ws(websocket: WebSocket):
    await websocket.accept()
    predictor = app.state.predictor
    current_image = None
    try:
        while True:
            raw = await websocket.receive_text()
            message = json.loads(raw)
            action = message.get("action")

            if action == "set_image":
                # Decode base64 image from the browser
                img_bytes = base64.b64decode(message["image_b64"])
                pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                current_image = np.array(pil_image)
                with torch.inference_mode(), torch.autocast(
                    "cuda", dtype=torch.bfloat16
                ):
                    predictor.set_image(current_image)
                await websocket.send_json(
                    {"status": "ready", "width": current_image.shape[1],
                     "height": current_image.shape[0]}
                )

            elif action == "click":
                if current_image is None:
                    await websocket.send_json({"error": "No image set"})
                    continue

                # Extract click coordinates and label
                x = message["x"]
                y = message["y"]
                label = message.get("label", 1)  # 1=foreground, 0=background
                point_coords = np.array([[x, y]], dtype=np.float32)
                point_labels = np.array([label], dtype=np.int32)

                with torch.inference_mode(), torch.autocast(
                    "cuda", dtype=torch.bfloat16
                ):
                    masks, scores, _ = predictor.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        multimask_output=False,
                    )

                # Encode mask as base64 PNG
                mask_image = Image.fromarray((masks[0] * 255).astype(np.uint8))
                buf = io.BytesIO()
                mask_image.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                await websocket.send_json({
                    "mask_b64": mask_b64,
                    "score": float(scores[0]),
                    "width": masks.shape[2],
                    "height": masks.shape[1],
                })
    except WebSocketDisconnect:
        print("Client disconnected")
```
Run it with:
```shell
uvicorn server:app --host 0.0.0.0 --port 8000
```
## Processing Point Prompts
The browser sends raw pixel coordinates from click events. SAM 2 expects point_coords as a numpy array of shape (N, 2) in (x, y) format, and point_labels as shape (N,) where 1 means foreground and 0 means background.
Multi-point prompts give better results. When the user clicks multiple times, accumulate the points and send them all at once:
```python
def build_prompt(clicks: list[dict]) -> tuple[np.ndarray, np.ndarray]:
    """Convert a list of browser clicks to SAM 2 prompt arrays.

    Each click is {"x": int, "y": int, "label": int}.
    Label 1 = include this object, 0 = exclude this region.
    """
    coords = np.array([[c["x"], c["y"]] for c in clicks], dtype=np.float32)
    labels = np.array([c["label"] for c in clicks], dtype=np.int32)
    return coords, labels


# Example: user clicked foreground at (350, 250), background at (100, 50)
clicks = [
    {"x": 350, "y": 250, "label": 1},
    {"x": 100, "y": 50, "label": 0},
]
point_coords, point_labels = build_prompt(clicks)

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

# Pick the mask with the highest IoU score
best_idx = np.argmax(scores)
best_mask = masks[best_idx]  # shape: (H, W), boolean
print(f"Best mask score: {scores[best_idx]:.3f}")
```
Setting multimask_output=True returns three candidate masks. The one with the highest score is usually best, but you can show all three and let the user pick.
## Encoding and Streaming Masks
Raw boolean masks are large. A 1920x1080 mask is over 2 MB uncompressed. You need efficient encoding to keep latency low.
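To see why encoding matters, compare the raw sizes involved. A quick back-of-the-envelope in Python (exact for a boolean mask stored one byte per pixel; base64 expands binary by another third):

```python
width, height = 1920, 1080

raw_bytes = width * height           # one byte per pixel for a bool array
print(raw_bytes)                     # 2073600 -> ~2.0 MB uncompressed

# Base64 turns every 3 bytes into 4 ASCII characters, so sending the
# raw mask as base64 text would be even worse:
b64_bytes = -(-raw_bytes // 3) * 4   # ceil(raw_bytes / 3) * 4
print(b64_bytes)                     # 2764800 -> ~2.6 MB
```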
### Base64 PNG
PNG compression works well for binary masks. Typical compression ratios hit 10-50x on segmentation masks because they have large uniform regions:
```python
import io
import base64

from PIL import Image
import numpy as np


def mask_to_base64_png(mask: np.ndarray) -> str:
    """Encode a boolean mask as a base64 PNG string.

    Args:
        mask: Boolean numpy array of shape (H, W).

    Returns:
        Base64-encoded PNG string.
    """
    # Convert boolean mask to uint8 (0 or 255)
    mask_uint8 = mask.astype(np.uint8) * 255
    img = Image.fromarray(mask_uint8, mode="L")
    buf = io.BytesIO()
    img.save(buf, format="PNG", optimize=True)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def base64_png_to_mask(b64_string: str) -> np.ndarray:
    """Decode a base64 PNG string back to a boolean mask."""
    raw = base64.b64decode(b64_string)
    img = Image.open(io.BytesIO(raw)).convert("L")
    return np.array(img) > 127
```
### Run-Length Encoding
For even smaller payloads, run-length encoding (RLE) compresses binary masks to a few hundred bytes:
```python
def mask_to_rle(mask: np.ndarray) -> dict:
    """Encode a binary mask using run-length encoding.

    Runs alternate background/foreground and always start with a
    background run, following the COCO convention.

    Args:
        mask: Boolean numpy array of shape (H, W).

    Returns:
        Dict with 'counts' (list of run lengths) and 'size' [H, W].
    """
    flat = mask.flatten(order="F")  # column-major, COCO convention
    # Run boundaries: positions where the value changes, plus both ends
    changes = np.where(flat[1:] != flat[:-1])[0] + 1
    boundaries = np.concatenate([[0], changes, [flat.size]])
    counts = np.diff(boundaries).tolist()
    if flat.size and flat[0]:
        counts = [0] + counts  # first run must be background
    return {"counts": counts, "size": list(mask.shape)}


def rle_to_mask(rle: dict) -> np.ndarray:
    """Decode an RLE dict back to a boolean mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=bool)
    position = 0
    for i, count in enumerate(rle["counts"]):
        if i % 2 == 1:  # odd runs are foreground
            flat[position : position + count] = True
        position += count
    return flat.reshape((h, w), order="F")
```
For the WebSocket server, base64 PNG is the pragmatic choice. The browser can display it directly as an image overlay without any decoding library.
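For reference, the messages exchanged with this server are plain JSON. A sketch of how a client would build them in Python (the field names match the handler above; the helper functions themselves are illustrative):

```python
import base64
import json


def set_image_message(image_bytes: bytes) -> str:
    """Build the set_image message the /ws/segment handler expects."""
    return json.dumps({
        "action": "set_image",
        "image_b64": base64.b64encode(image_bytes).decode("utf-8"),
    })


def click_message(x: int, y: int, label: int = 1) -> str:
    """Build a click message; label 1 = foreground, 0 = background."""
    return json.dumps({"action": "click", "x": x, "y": y, "label": label})


msg = json.loads(click_message(350, 250))
print(msg)  # {'action': 'click', 'x': 350, 'y': 250, 'label': 1}
```

A browser client would build the same objects with `JSON.stringify` and send them over a plain `WebSocket`.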
## Common Errors and Fixes
### CUDA Out of Memory
This happens when you load the large model on a GPU with limited VRAM, or when multiple images get their embeddings cached simultaneously.
```
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB.
GPU 0 has a total capacity of 7.79 GiB of which 112.50 MiB is free.
```
Fix: switch to a smaller model and call reset_predictor() between images to free cached embeddings:
```python
# Use the tiny model for lower memory footprint
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")

# After processing each image, reset to free the cached embedding
predictor.reset_predictor()
torch.cuda.empty_cache()
```
### WebSocket Disconnection During Inference
If the client disconnects while SAM 2 is running inference, the coroutine raises WebSocketDisconnect. The GPU memory from that inference call may not get freed cleanly.
```
starlette.websockets.WebSocketDisconnect: 1001
```
Fix: wrap inference in a try/except and always clean up:
```python
@app.websocket("/ws/segment")
async def segment_ws(websocket: WebSocket):
    await websocket.accept()
    predictor = app.state.predictor
    try:
        while True:
            raw = await websocket.receive_text()
            message = json.loads(raw)
            # ... handle message ...
    except WebSocketDisconnect:
        predictor.reset_predictor()
        print("Client disconnected, predictor state cleared")
    except Exception as e:
        predictor.reset_predictor()
        await websocket.close(code=1011, reason=str(e))
```
### Image Size Mismatch
Browser canvases often scale images for display. If the user clicks on a scaled canvas, the coordinates will not match the original image dimensions. You get masks that segment the wrong region.
```python
# No error raised, but the mask covers the wrong area entirely.
# The click at (200, 150) on a 50%-scaled canvas should map to (400, 300).
```
Fix: send the original image dimensions along with the click, and scale coordinates on the server:
```python
def scale_coords(
    click_x: int,
    click_y: int,
    canvas_width: int,
    canvas_height: int,
    image_width: int,
    image_height: int,
) -> tuple[float, float]:
    """Map canvas click coordinates to original image pixel coordinates."""
    scale_x = image_width / canvas_width
    scale_y = image_height / canvas_height
    return click_x * scale_x, click_y * scale_y
```
Always validate that the scaled coordinates fall within the image bounds before passing them to predictor.predict().
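A minimal sketch of that final check (the `clamp_coords` helper is hypothetical; clamping is one reasonable policy, rejecting out-of-bounds clicks outright is another):

```python
def clamp_coords(
    x: float, y: float, image_width: int, image_height: int
) -> tuple[float, float]:
    """Clamp scaled click coordinates into the image's pixel bounds."""
    x = min(max(x, 0.0), image_width - 1.0)
    y = min(max(y, 0.0), image_height - 1.0)
    return x, y


print(clamp_coords(400.0, 300.0, 1920, 1080))  # in bounds -> (400.0, 300.0)
print(clamp_coords(2000.0, -5.0, 1920, 1080))  # clamped  -> (1919.0, 0.0)
```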