SAM 2 produces precise segmentation masks from a single click. The model takes point or box prompts and, once an image has been embedded, returns object boundaries in milliseconds on a GPU. Pair that with a WebSocket connection and you get an interactive annotation tool: users click on objects in the browser and see segmentation overlays almost immediately.

Here is what you need installed:

pip install 'git+https://github.com/facebookresearch/sam2.git'
pip install fastapi uvicorn websockets pillow numpy huggingface_hub

You will also need a GPU with at least 8 GB of VRAM for the large model, or 4 GB for the tiny variant.

Loading SAM 2 for Inference

The SAM2ImagePredictor class handles embedding computation and mask prediction. Load it once at startup and reuse it across requests. The from_pretrained method pulls weights directly from Hugging Face, so you do not need to manage checkpoint files manually.

import torch
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Load the model once -- this downloads weights on first run
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

# Load and set an image (HWC format, uint8, RGB)
image = np.array(Image.open("photo.jpg").convert("RGB"))

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)

    # Single point prompt: (x, y) coordinates, label 1 = foreground
    point_coords = np.array([[350, 250]])
    point_labels = np.array([1])

    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )

# masks shape: (1, H, W) -- float32 array of 0.0/1.0 values (use masks > 0.5 for booleans)
print(f"Mask shape: {masks.shape}, Score: {scores[0]:.3f}")

The set_image call computes the image embedding. This is the expensive step: roughly 100-200 ms on an RTX 3090 with the large model. After that, each predict call takes only about 5-10 ms, because it runs just the prompt encoder and the lightweight mask decoder against the cached embedding.
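Since set_image is the expensive call, it is worth skipping it when a client re-sends an image that is already embedded. The sketch below is a hypothetical wrapper (EmbeddingCache is not part of SAM 2) that assumes only that the predictor exposes set_image(image) the way SAM2ImagePredictor does; it remembers which image is embedded by hashing the pixel data.

```python
import hashlib

import numpy as np


class EmbeddingCache:
    """Hypothetical helper, not part of the sam2 package.

    Calls predictor.set_image() only when the incoming image differs
    from the one whose embedding is already cached.
    """

    def __init__(self, predictor):
        self.predictor = predictor
        self._current_key = None

    @staticmethod
    def _key(image: np.ndarray) -> str:
        # Hash pixels, shape, and dtype so identical uploads map to one key
        h = hashlib.sha1(image.tobytes())
        h.update(str(image.shape).encode())
        h.update(str(image.dtype).encode())
        return h.hexdigest()

    def ensure_image(self, image: np.ndarray) -> bool:
        """Embed the image if needed; return True if set_image actually ran."""
        key = self._key(image)
        if key == self._current_key:
            return False  # embedding already cached, skip the expensive step
        self.predictor.set_image(image)
        self._current_key = key
        return True

    def invalidate(self) -> None:
        """Forget the cached key, e.g. after predictor.reset_predictor()."""
        self._current_key = None
```

With a real predictor, call cache.ensure_image(image) inside the same torch.inference_mode() and torch.autocast block you would use for set_image, and call cache.invalidate() whenever you reset the predictor.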

Choosing the Right Model Size

SAM 2 ships in four sizes. For real-time annotation, the tiny model gives you sub-50 ms embeddings at a small quality cost:

  • facebook/sam2-hiera-tiny – fastest, good for interactive tools
  • facebook/sam2-hiera-small – balanced speed and accuracy
  • facebook/sam2-hiera-base-plus – higher quality masks
  • facebook/sam2-hiera-large – best quality, slower embedding

Building the WebSocket Server

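The server needs to load SAM 2 once at startup, accept image uploads and click coordinates over WebSocket, and stream back segmentation masks. FastAPI’s lifespan context manager is the right pattern for model initialization. This layout assumes one annotator at a time: every connection shares the same predictor, and the predictor holds only one image embedding, so a second client calling set_image replaces the first client's image.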

# server.py
import io
import json
import base64
from contextlib import asynccontextmanager

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from sam2.sam2_image_predictor import SAM2ImagePredictor


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: pick a device first, then load SAM 2 onto it
    # (from_pretrained places the model on "cuda" unless told otherwise)
    app.state.device = "cuda" if torch.cuda.is_available() else "cpu"
    app.state.predictor = SAM2ImagePredictor.from_pretrained(
        "facebook/sam2-hiera-large", device=app.state.device
    )
    print(f"SAM 2 loaded on {app.state.device}")
    yield
    # Shutdown: free GPU memory
    del app.state.predictor
    torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)


@app.websocket("/ws/segment")
async def segment_ws(websocket: WebSocket):
    await websocket.accept()
    predictor = app.state.predictor
    current_image = None

    try:
        while True:
            raw = await websocket.receive_text()
            message = json.loads(raw)
            action = message.get("action")

            if action == "set_image":
                # Decode base64 image from the browser
                img_bytes = base64.b64decode(message["image_b64"])
                pil_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                current_image = np.array(pil_image)

                with torch.inference_mode(), torch.autocast(
                    "cuda", dtype=torch.bfloat16
                ):
                    predictor.set_image(current_image)

                await websocket.send_json(
                    {"status": "ready", "width": current_image.shape[1],
                     "height": current_image.shape[0]}
                )

            elif action == "click":
                if current_image is None:
                    await websocket.send_json({"error": "No image set"})
                    continue

                # Extract click coordinates and label
                x = message["x"]
                y = message["y"]
                label = message.get("label", 1)  # 1=foreground, 0=background

                point_coords = np.array([[x, y]], dtype=np.float32)
                point_labels = np.array([label], dtype=np.int32)

                with torch.inference_mode(), torch.autocast(
                    "cuda", dtype=torch.bfloat16
                ):
                    masks, scores, _ = predictor.predict(
                        point_coords=point_coords,
                        point_labels=point_labels,
                        multimask_output=False,
                    )

                # Encode mask as base64 PNG
                mask_image = Image.fromarray((masks[0] * 255).astype(np.uint8))
                buf = io.BytesIO()
                mask_image.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                await websocket.send_json({
                    "mask_b64": mask_b64,
                    "score": float(scores[0]),
                    "width": masks.shape[2],
                    "height": masks.shape[1],
                })

    except WebSocketDisconnect:
        print("Client disconnected")

Run it with:

uvicorn server:app --host 0.0.0.0 --port 8000
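To exercise the server before the browser side exists, these hypothetical helpers build the two JSON messages the handler above reads (action, image_b64, x, y, label). The data-URL stripping assumes a client that sends canvas.toDataURL() output, which carries a "data:image/png;base64," prefix that the server's plain base64.b64decode call does not remove.

```python
import base64
import json


def build_set_image_message(image_bytes: bytes) -> str:
    """JSON text for the "set_image" action: raw base64, no data-URL prefix."""
    return json.dumps({
        "action": "set_image",
        "image_b64": base64.b64encode(image_bytes).decode("ascii"),
    })


def build_click_message(x: float, y: float, label: int = 1) -> str:
    """JSON text for the "click" action; label 1 = foreground, 0 = background."""
    if label not in (0, 1):
        raise ValueError("label must be 0 or 1")
    return json.dumps({"action": "click", "x": x, "y": y, "label": label})


def strip_data_url(b64_or_data_url: str) -> str:
    """Drop a "data:<mime>;base64," prefix, as produced by canvas.toDataURL().

    base64.b64decode() discards characters outside the base64 alphabet by
    default, so an unstripped prefix is partly kept and the payload does
    not decode correctly.
    """
    if b64_or_data_url.startswith("data:"):
        _, _, payload = b64_or_data_url.partition(",")
        return payload
    return b64_or_data_url
```

On the server, applying strip_data_url to message["image_b64"] before decoding lets both raw base64 and data-URL clients work.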

Processing Point Prompts

The browser sends raw pixel coordinates from click events. SAM 2 expects point_coords as a numpy array of shape (N, 2) in (x, y) format, and point_labels as shape (N,) where 1 means foreground and 0 means background.

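Adding points usually refines the mask: a foreground click recovers a missed part, and a background click removes a region that spilled over. When the user clicks multiple times, accumulate the points and pass them all to a single predict call: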

import numpy as np
import torch


def build_prompt(clicks: list[dict]) -> tuple[np.ndarray, np.ndarray]:
    """Convert a list of browser clicks to SAM 2 prompt arrays.

    Each click is {"x": int, "y": int, "label": int}.
    Label 1 = include this object, 0 = exclude this region.
    """
    coords = np.array([[c["x"], c["y"]] for c in clicks], dtype=np.float32)
    labels = np.array([c["label"] for c in clicks], dtype=np.int32)
    return coords, labels


# Example: user clicked foreground at (350, 250), background at (100, 50)
clicks = [
    {"x": 350, "y": 250, "label": 1},
    {"x": 100, "y": 50, "label": 0},
]
point_coords, point_labels = build_prompt(clicks)

# `predictor` is the SAM2ImagePredictor loaded earlier, with set_image() already called
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )

# Pick the mask with the highest predicted IoU score
best_idx = np.argmax(scores)
best_mask = masks[best_idx]  # shape: (H, W), float32 values 0.0/1.0
print(f"Best mask score: {scores[best_idx]:.3f}")

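Setting multimask_output=True returns three candidate masks. The one with the highest score is usually best, but you can show all three and let the user pick. Multiple candidates help most with a single, ambiguous click (part versus whole object); once several points pin the object down, multimask_output=False is usually enough.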
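The server shown earlier handles one click per message. One way to accumulate points as described above is to keep a per-connection click history and rebuild the prompt arrays on every click. ClickSession is a hypothetical helper; it produces the same (N, 2) float32 and (N,) int32 arrays as build_prompt, and adds undo and reset.

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ClickSession:
    """Hypothetical per-connection click history for multi-point prompts."""

    clicks: list[tuple[float, float, int]] = field(default_factory=list)

    def add(self, x: float, y: float, label: int = 1) -> None:
        if label not in (0, 1):
            raise ValueError("label must be 1 (foreground) or 0 (background)")
        self.clicks.append((x, y, label))

    def undo(self) -> None:
        # Remove the most recent click, if there is one
        if self.clicks:
            self.clicks.pop()

    def reset(self) -> None:
        # Call this whenever a new image is set
        self.clicks.clear()

    def to_prompt(self) -> tuple[np.ndarray, np.ndarray]:
        """Return point_coords with shape (N, 2) and point_labels with shape (N,)."""
        coords = np.array(
            [(x, y) for x, y, _ in self.clicks], dtype=np.float32
        ).reshape(-1, 2)
        labels = np.array([lab for _, _, lab in self.clicks], dtype=np.int32)
        return coords, labels
```

In the click branch you would call session.add(x, y, label) and pass the arrays from session.to_prompt() to predictor.predict(); reset the session when a new image arrives, and skip predict() while the history is empty.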

Encoding and Streaming Masks

Raw masks are large: at one byte per pixel, a 1920x1080 mask is over 2 MB uncompressed. You need efficient encoding to keep latency low.
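The arithmetic behind that figure, and what a few other representations would cost, can be checked with NumPy. The float32 line assumes SAM2ImagePredictor.predict returns float32 0.0/1.0 masks (check masks.dtype in your install); np.packbits appears only as a size reference for a 1-bit-per-pixel layout.

```python
import base64

import numpy as np

H, W = 1080, 1920

bool_mask = np.zeros((H, W), dtype=bool)           # 1 byte per pixel
float_mask = np.zeros((H, W), dtype=np.float32)    # 4 bytes per pixel
packed = np.packbits(bool_mask)                    # 1 bit per pixel, flattened
b64_len = len(base64.b64encode(packed.tobytes()))  # base64 adds 4/3 overhead

print(bool_mask.nbytes)   # 2073600
print(float_mask.nbytes)  # 8294400
print(packed.nbytes)      # 259200
print(b64_len)            # 345600
```

The last line is a reminder that any binary payload sent inside JSON as base64 grows by a third.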

Base64 PNG

PNG compression works well for binary masks. Compression ratios of 10-50x are common on segmentation masks because they consist of large uniform regions:

import io
import base64
from PIL import Image
import numpy as np


def mask_to_base64_png(mask: np.ndarray) -> str:
    """Encode a boolean mask as a base64 PNG string.

    Args:
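        mask: Numpy array of shape (H, W), boolean or float 0.0/1.0
            (SAM2ImagePredictor.predict returns the latter).

    Returns:
        Base64-encoded PNG string.
    """
    # Convert the mask to uint8 (0 or 255); a 2D uint8 array becomes an 8-bit
    # grayscale ("L") image, so no mode= argument is needed (newer Pillow
    # releases deprecate passing it here)
    mask_uint8 = (mask > 0.5).astype(np.uint8) * 255
    img = Image.fromarray(mask_uint8)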

    buf = io.BytesIO()
    img.save(buf, format="PNG", optimize=True)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def base64_png_to_mask(b64_string: str) -> np.ndarray:
    """Decode a base64 PNG string back to a boolean mask."""
    raw = base64.b64decode(b64_string)
    img = Image.open(io.BytesIO(raw)).convert("L")
    return np.array(img) > 127

Run-Length Encoding

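For more compact payloads, run-length encoding (RLE) stores only the lengths of alternating background and foreground runs, so its size grows with the number of boundary crossings rather than with the pixel count. A compact object can fit in a few hundred bytes: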

def mask_to_rle(mask: np.ndarray) -> dict:
    """Encode a binary mask using run-length encoding.

    Args:
        mask: Boolean numpy array of shape (H, W).

    Returns:
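        Dict with 'counts' (alternating run lengths, starting with a
        background run) and 'size' [H, W].
    """
    flat = mask.flatten(order="F").astype(bool)  # column-major, COCO convention
    # Indices where the value differs from the previous pixel
    change = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate([[0], change, [flat.size]])
    counts = np.diff(boundaries).tolist()
    # COCO-style RLE always starts with a background run, possibly of length 0
    if flat.size and flat[0]:
        counts = [0] + counts
    return {"counts": counts, "size": list(mask.shape)}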


def rle_to_mask(rle: dict) -> np.ndarray:
    """Decode an RLE dict back to a boolean mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=bool)
    position = 0
    for i, count in enumerate(rle["counts"]):
        if i % 2 == 1:  # odd runs are foreground
            flat[position : position + count] = True
        position += count
    return flat.reshape((h, w), order="F")

For the WebSocket server, base64 PNG is the pragmatic choice. The browser can display it directly as an image overlay without any decoding library.
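One caveat: the grayscale mask PNG is opaque black wherever the object is absent, so drawn straight over the photo it hides the background. The sketch below (mask_to_rgba is a hypothetical name, and the default color and opacity are arbitrary choices) builds a transparent RGBA overlay instead, using only NumPy.

```python
import numpy as np


def mask_to_rgba(
    mask: np.ndarray, color=(30, 144, 255), alpha: int = 128
) -> np.ndarray:
    """Turn an (H, W) mask into an (H, W, 4) uint8 RGBA overlay.

    Foreground pixels get `color` at opacity `alpha` (0-255); background
    pixels stay fully transparent so the photo shows through.
    """
    fg = np.asarray(mask) > 0.5  # works for bool masks and float 0.0/1.0 masks
    rgba = np.zeros((*fg.shape, 4), dtype=np.uint8)
    rgba[fg, :3] = color  # broadcast the RGB triple onto every foreground pixel
    rgba[fg, 3] = alpha
    return rgba
```

Passing the result to Image.fromarray gives an RGBA image; saved as PNG and base64-encoded, it can be assigned to an img element as a data: URL and stacked over the original picture.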

Common Errors and Fixes

CUDA Out of Memory

This happens when you load the large model on a GPU with limited VRAM, or when multiple images get their embeddings cached simultaneously.

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB.
GPU 0 has a total capacity of 7.79 GiB of which 112.50 MiB is free.

Fix: switch to a smaller model and call reset_predictor() between images to free cached embeddings:

# Use the tiny model for lower memory footprint
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")

# After processing each image, reset to free the cached embedding
predictor.reset_predictor()
torch.cuda.empty_cache()

WebSocket Disconnection During Inference

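If the client disconnects, Starlette raises WebSocketDisconnect the next time the handler awaits the socket, usually in receive_text(). Because set_image() and predict() run synchronously inside the async handler, a disconnect that happens mid-inference is only noticed after the call returns. Until you clear it, the shared predictor keeps that client's image embedding on the GPU.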

starlette.websockets.WebSocketDisconnect: 1001

Fix: catch the disconnect and clear the predictor state in a finally block, so cleanup runs on every exit path:

@app.websocket("/ws/segment")
async def segment_ws(websocket: WebSocket):
    await websocket.accept()
    predictor = app.state.predictor
    try:
        while True:
            raw = await websocket.receive_text()
            message = json.loads(raw)
            # ... handle message ...
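    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        # 1011 = internal error; a close-frame reason may be at most 123 bytes
        reason = str(e).encode("utf-8")[:120].decode("utf-8", errors="ignore")
        await websocket.close(code=1011, reason=reason)
    finally:
        # Runs on every exit path: drop the cached embedding, release cached GPU memory
        predictor.reset_predictor()
        torch.cuda.empty_cache()
        print("Predictor state cleared")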

Image Size Mismatch

Browser canvases often scale images for display. If the user clicks on a scaled canvas, the coordinates will not match the original image dimensions. You get masks that segment the wrong region.

# No error raised, but the mask covers the wrong area entirely.
# The click at (200, 150) on a 50%-scaled canvas should map to (400, 300).

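Fix: have the client send the canvas's displayed width and height along with each click, and scale the coordinates on the server, which already knows the original image dimensions from the set_image step: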

def scale_coords(
    click_x: int,
    click_y: int,
    canvas_width: int,
    canvas_height: int,
    image_width: int,
    image_height: int,
) -> tuple[float, float]:
    """Map canvas click coordinates to original image pixel coordinates."""
    scale_x = image_width / canvas_width
    scale_y = image_height / canvas_height
    return click_x * scale_x, click_y * scale_y

Always validate that the scaled coordinates fall within the image bounds before passing them to predictor.predict().
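A minimal sketch of that validation, combining the scale_coords mapping with a bounds check. to_image_coords is a hypothetical name; it raises ValueError rather than clamping, on the assumption that a click outside the image is a client bug worth reporting.

```python
def to_image_coords(
    click_x: float,
    click_y: float,
    canvas_width: int,
    canvas_height: int,
    image_width: int,
    image_height: int,
) -> tuple[float, float]:
    """Scale a canvas click to image pixels and reject clicks outside the image."""
    if canvas_width <= 0 or canvas_height <= 0:
        raise ValueError("canvas dimensions must be positive")
    # Same mapping as scale_coords: multiply by image size / canvas size
    x = click_x * image_width / canvas_width
    y = click_y * image_height / canvas_height
    # Valid pixel coordinates lie in [0, width) and [0, height)
    if not (0 <= x < image_width and 0 <= y < image_height):
        raise ValueError(
            f"click ({x:.1f}, {y:.1f}) is outside the {image_width}x{image_height} image"
        )
    return x, y
```

For the 50%-scaled example above, to_image_coords(200, 150, 960, 540, 1920, 1080) returns (400.0, 300.0). In the click handler, you could catch the ValueError and reply with an {"error": ...} message, as the server already does when no image is set.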