You trained a model, pushed it to a registry, and someone on the team loaded it into production. But how do you know the file wasn’t tampered with between training and serving? You don’t – unless you sign it.
Model artifact signing gives you two guarantees: integrity (the file hasn’t changed) and provenance (you know who produced it). The approach is straightforward. Hash the model file with SHA-256, sign that hash with an RSA private key, and bundle everything into a manifest. Before loading any model, recompute the hash and verify the manifest's signature against the public key.
Generate RSA Key Pairs
First, generate a key pair. The private key stays on your training infrastructure. The public key goes everywhere models get loaded.
```python
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from pathlib import Path

def generate_key_pair(private_key_path: Path, public_key_path: Path) -> None:
    """Generate an RSA key pair and save to disk."""
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=4096,
    )
    # Write private key (protect this file)
    private_key_path.write_bytes(
        private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.BestAvailableEncryption(b"your-passphrase"),
        )
    )
    # Write public key (distribute freely)
    public_key = private_key.public_key()
    public_key_path.write_bytes(
        public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
    )
    print(f"Keys written to {private_key_path} and {public_key_path}")

# Generate keys
generate_key_pair(Path("signing_key.pem"), Path("signing_key.pub"))
```
Use 4096-bit keys. The extra computation cost over 2048-bit is negligible for signing operations, and you get a much larger security margin. The passphrase on the private key is important – store it in a secrets manager, not in your code.
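One way to keep the passphrase out of your code is to read it from an environment variable that your secrets manager injects at runtime. A minimal sketch — the variable name `MODEL_SIGNING_PASSPHRASE` is just an example, not something the library defines:

```python
import os

def load_signing_passphrase(var: str = "MODEL_SIGNING_PASSPHRASE") -> bytes:
    """Read the key passphrase from an environment variable injected by a secrets manager."""
    value = os.environ.get(var)
    if value is None:
        # Fail loudly rather than fall back to a hardcoded passphrase
        raise RuntimeError(f"{var} is not set")
    return value.encode("utf-8")
```

Failing loudly here is deliberate: a missing secret should stop the pipeline, not silently produce unsigned or mis-signed artifacts.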
Sign a Model Artifact
Signing works in two steps: hash the file, then sign the hash. The manifest captures everything a verifier needs.
```python
import base64
import hashlib
import json
from datetime import datetime, timezone
from pathlib import Path
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding

def hash_file(file_path: Path) -> str:
    """Compute the SHA-256 hash of a file, reading in chunks."""
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(8192):
            sha256.update(chunk)
    return sha256.hexdigest()

def sign_artifact(
    model_path: Path,
    private_key_path: Path,
    passphrase: bytes,
    signer_identity: str,
) -> dict:
    """Sign a model artifact and return a manifest dict."""
    # Load private key
    private_key = serialization.load_pem_private_key(
        private_key_path.read_bytes(),
        password=passphrase,
    )
    # Hash the model file
    file_hash = hash_file(model_path)
    # Sign the hex digest string (verifiers must use the same encoding)
    signature = private_key.sign(
        file_hash.encode("utf-8"),
        padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        ),
        hashes.SHA256(),
    )
    manifest = {
        "model_file": model_path.name,
        "sha256": file_hash,
        "signature": base64.b64encode(signature).decode("utf-8"),
        "signer": signer_identity,
        "signed_at": datetime.now(timezone.utc).isoformat(),
        "algorithm": "RSA-PSS-SHA256",
        "key_size": 4096,
    }
    return manifest

def save_manifest(manifest: dict, output_path: Path) -> None:
    """Write the signing manifest to a JSON file."""
    output_path.write_text(json.dumps(manifest, indent=2))
    print(f"Manifest saved to {output_path}")
```
A few things worth noting. We use RSA-PSS padding instead of PKCS1v15 – PSS is the modern standard and provides a stronger security proof. The manifest includes the algorithm and key size so verifiers know exactly what to expect. The timestamp uses UTC to avoid timezone ambiguity.
Verify Before Loading
The verification function is what runs in your serving infrastructure. It takes the manifest, the model file, and the public key. If anything is off, it raises an exception.
```python
import base64
import json
from pathlib import Path
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding

def verify_artifact(
    model_path: Path,
    manifest_path: Path,
    public_key_path: Path,
) -> dict:
    """Verify a model artifact against its manifest. Raises on failure."""
    manifest = json.loads(manifest_path.read_text())
    # Step 1: Check the file hash matches (hash_file is defined above)
    actual_hash = hash_file(model_path)
    if actual_hash != manifest["sha256"]:
        raise ValueError(
            f"Hash mismatch: expected {manifest['sha256']}, got {actual_hash}"
        )
    # Step 2: Verify the signature
    public_key = serialization.load_pem_public_key(
        public_key_path.read_bytes()
    )
    signature = base64.b64decode(manifest["signature"])
    try:
        public_key.verify(
            signature,
            manifest["sha256"].encode("utf-8"),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH,
            ),
            hashes.SHA256(),
        )
    except InvalidSignature:
        raise ValueError(
            f"Invalid signature for {model_path.name} -- "
            f"claimed signer: {manifest['signer']}"
        )
    print(f"Verified: {model_path.name} signed by {manifest['signer']} at {manifest['signed_at']}")
    return manifest
```
The order matters. Check the hash first, then verify the signature. If the hash doesn’t match, the file was modified and there’s no point checking the signature. If the hash matches but the signature is invalid, someone recomputed the hash after tampering.
Integrate with PyTorch Model Saving and Loading
Here’s how to wire signing into your actual PyTorch workflow. The save_signed_model function saves the state dict and produces a manifest. The load_verified_model function refuses to load anything that doesn’t pass verification.
```python
import torch
import torch.nn as nn
from pathlib import Path

class SimpleClassifier(nn.Module):
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

def save_signed_model(
    model: nn.Module,
    model_path: Path,
    private_key_path: Path,
    passphrase: bytes,
    signer_identity: str,
) -> Path:
    """Save a PyTorch model and create a signed manifest."""
    torch.save(model.state_dict(), model_path)
    manifest = sign_artifact(model_path, private_key_path, passphrase, signer_identity)
    manifest_path = model_path.with_suffix(".manifest.json")
    save_manifest(manifest, manifest_path)
    return manifest_path

def load_verified_model(
    model: nn.Module,
    model_path: Path,
    public_key_path: Path,
) -> nn.Module:
    """Load a PyTorch model only if verification passes."""
    manifest_path = model_path.with_suffix(".manifest.json")
    if not manifest_path.exists():
        raise FileNotFoundError(f"No manifest found at {manifest_path}")
    # This raises ValueError if verification fails
    verify_artifact(model_path, manifest_path, public_key_path)
    model.load_state_dict(torch.load(model_path, weights_only=True))
    model.eval()
    return model

# --- Full example ---
# Training side
model = SimpleClassifier(input_dim=768, num_classes=10)
# ... training happens here ...
save_signed_model(
    model=model,
    model_path=Path("classifier_v1.pt"),
    private_key_path=Path("signing_key.pem"),
    passphrase=b"your-passphrase",
    signer_identity="training-pipeline@ml-team",
)

# Serving side
served_model = SimpleClassifier(input_dim=768, num_classes=10)
served_model = load_verified_model(
    model=served_model,
    model_path=Path("classifier_v1.pt"),
    public_key_path=Path("signing_key.pub"),
)
```
The weights_only=True flag in torch.load is critical. Without it, PyTorch uses pickle under the hood, which can execute arbitrary code. Even with signature verification, you should always set this flag.
Batch Signing Multiple Artifacts
Real training pipelines produce more than one file. You might have the model weights, a tokenizer config, and a preprocessing pipeline. Sign them all and bundle the manifests.
```python
def sign_artifact_bundle(
    artifact_paths: list[Path],
    private_key_path: Path,
    passphrase: bytes,
    signer_identity: str,
    bundle_output: Path,
) -> None:
    """Sign multiple artifacts and write a single bundle manifest."""
    manifests = []
    for path in artifact_paths:
        manifest = sign_artifact(path, private_key_path, passphrase, signer_identity)
        manifests.append(manifest)
    bundle = {
        "artifacts": manifests,
        "bundle_signed_at": datetime.now(timezone.utc).isoformat(),
        "artifact_count": len(manifests),
    }
    bundle_output.write_text(json.dumps(bundle, indent=2))
    print(f"Bundle manifest with {len(manifests)} artifacts saved to {bundle_output}")

# Sign model + config together
sign_artifact_bundle(
    artifact_paths=[Path("classifier_v1.pt"), Path("config.json")],
    private_key_path=Path("signing_key.pem"),
    passphrase=b"your-passphrase",
    signer_identity="training-pipeline@ml-team",
    bundle_output=Path("release_bundle.manifest.json"),
)
```
Common Errors and Fixes
ValueError: Could not deserialize key data when loading the private key. This usually means the passphrase is wrong or the file is corrupted. Double-check your passphrase. If you saved the key without encryption, pass password=None to load_pem_private_key.
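To see the distinction in isolation, here is a small self-contained sketch: a throwaway key serialized with `NoEncryption` loads only when `password=None` is passed.

```python
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

# Throwaway 2048-bit key, serialized WITHOUT encryption
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pem = key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
)
# An unencrypted key must be loaded with password=None;
# passing a passphrase here raises an error instead
reloaded = serialization.load_pem_private_key(pem, password=None)
assert reloaded.key_size == 2048
```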
InvalidSignature during verification when you’re sure the file hasn’t changed. The most common cause is signing with one padding scheme and verifying with another. Make sure both sides use padding.PSS with identical parameters. Another cause: the signed data doesn’t match exactly. If you sign the hex hash string, verify against the hex hash string – not the raw bytes.
TypeError: a bytes-like object is required when calling sign() or verify(). The data argument must be bytes, not str. Always encode strings with .encode("utf-8") before passing them to cryptographic functions.
Hash mismatch on large files. If you’re reading the file in chunks (as we do above), make sure nothing else is writing to the file concurrently. On network filesystems, stale reads can also cause mismatches. Copy the file locally before hashing if your storage layer doesn’t guarantee read consistency.
torch.load warnings about pickle. PyTorch 2.6+ defaults to weights_only=True, but older versions don’t. Always pass it explicitly. If your checkpoint contains custom objects beyond the state dict, you’ll need to allowlist them with torch.serialization.add_safe_globals().
Permission denied on key files. Private keys should be chmod 600 – readable only by the owner. If your CI runner uses a different user, make sure the key file permissions are set correctly in your pipeline configuration.
Key Rotation Strategy
Don’t use the same key pair forever. Rotate keys on a schedule – quarterly is a reasonable starting point. When you rotate, keep the old public key around so you can still verify models signed before the rotation. A simple approach: name your keys with a version suffix like signing_key_v2.pem and include a key_version field in your manifest.
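Under that naming scheme, the verifier can pick the right key from the manifest's `key_version` field. A sketch, assuming public keys live together in one directory as `signing_key_v<N>.pub`:

```python
from pathlib import Path

def select_public_key(manifest: dict, key_dir: Path) -> Path:
    """Map the manifest's key_version field to a versioned public key file.

    Assumes keys are named signing_key_v<N>.pub; manifests without
    a key_version field default to version 1.
    """
    version = manifest.get("key_version", 1)
    key_path = key_dir / f"signing_key_v{version}.pub"
    if not key_path.exists():
        raise FileNotFoundError(f"No public key for version {version}: {key_path}")
    return key_path
```

The resolved path then feeds straight into `verify_artifact`, so old models keep verifying against their original keys while new models use the current one.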