## The Core Idea: SDS Loss Meets 3D Gaussians
Text-to-3D with Gaussian splatting represents a scene as thousands of 3D Gaussians, each a soft ellipsoid with a learnable position, covariance (factored into scale and rotation), opacity, and color. You render these Gaussians from a random camera angle using differentiable rasterization, then use a pretrained diffusion model (Stable Diffusion) to compute a Score Distillation Sampling (SDS) loss. That loss says, in effect, "this rendered image doesn't look like your text prompt," and you backpropagate through the renderer to update the Gaussian parameters. After a few thousand steps, the Gaussians arrange themselves into a coherent 3D scene that matches the prompt from every angle.
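In the notation of the DreamFusion paper, the SDS update is the weighted noise-prediction residual pushed through the render Jacobian:

$$
\nabla_\theta \mathcal{L}_{\text{SDS}} = \mathbb{E}_{t,\epsilon}\!\left[\, w(t)\,\big(\hat\epsilon_\phi(z_t;\, y, t) - \epsilon\big)\,\frac{\partial z}{\partial \theta} \,\right]
$$

where \(z\) is the latent of the rendered image, \(\theta\) the Gaussian parameters, \(y\) the text prompt, \(\epsilon\) the injected noise, and \(w(t)\) a timestep weighting. This is exactly what the training loop later in this post implements: the residual \(\hat\epsilon - \epsilon\) is injected as the gradient of the rendered image's latent.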
Install what you need:
```shell
pip install gsplat torch diffusers transformers accelerate plyfile numpy
```
gsplat is a fast CUDA-based Gaussian splatting renderer from Nerfstudio. It handles the differentiable rasterization so you don’t have to write custom CUDA kernels.
## Initializing the 3D Gaussians
Start by creating a random point cloud. Each Gaussian gets a 3D position, a scale, a rotation quaternion, an opacity, and an RGB color. All of these are learnable parameters.
```python
import torch
import torch.nn as nn
import numpy as np

device = torch.device("cuda")
num_gaussians = 10000

# Random positions clustered around the origin
positions = torch.randn(num_gaussians, 3, device=device) * 0.5

# Log-scales (exponentiated later to keep scales positive)
log_scales = torch.full((num_gaussians, 3), -3.0, device=device)

# Identity quaternions with small noise
quats = torch.zeros(num_gaussians, 4, device=device)
quats[:, 0] = 1.0
quats += torch.randn_like(quats) * 0.01
quats = quats / quats.norm(dim=-1, keepdim=True)

# Opacities stored in logit space (sigmoid-mapped at render time)
opacities_logit = torch.full((num_gaussians, 1), 0.5, device=device)

# RGB colors, also in logit space
colors_logit = torch.randn(num_gaussians, 3, device=device)

# Make everything a leaf tensor requiring grad
params = {
    "positions": nn.Parameter(positions),
    "log_scales": nn.Parameter(log_scales),
    "quats": nn.Parameter(quats),
    "opacities_logit": nn.Parameter(opacities_logit),
    "colors_logit": nn.Parameter(colors_logit),
}
```
The key design choice here: store scales in log-space and opacities/colors in logit-space. This way the optimizer works in unconstrained space while the actual values stay physically valid after applying exp and sigmoid.
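A quick sanity check of the two mappings (the specific initial values come from the snippet above; printed values are approximate):

```python
import torch

# A log-scale of -3.0 always maps to a small positive scale
log_scale = torch.tensor(-3.0)
scale = log_scale.exp()

# An opacity logit of 0.5 always maps into (0, 1)
opacity_logit = torch.tensor(0.5)
opacity = torch.sigmoid(opacity_logit)

print(round(scale.item(), 4))    # ~0.0498
print(round(opacity.item(), 4))  # ~0.6225
```

However far the optimizer pushes the raw values, `exp` and `sigmoid` keep the physical quantities valid, so no clamping or projection step is needed.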
## Rendering with gsplat and Computing SDS Loss
This is where the pieces come together. You render from a random camera, pass the rendered image to Stable Diffusion, and compute the SDS gradient.
```python
import torch.nn.functional as F
from gsplat import rasterization
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion for SDS guidance
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=True)

# Extract the components we need for SDS
vae = sd_pipe.vae
unet = sd_pipe.unet
text_encoder = sd_pipe.text_encoder
tokenizer = sd_pipe.tokenizer
scheduler = sd_pipe.scheduler

# Freeze the diffusion model -- only the Gaussians are optimized
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

# Encode the text prompt once
prompt = "a fantasy castle on a floating island, detailed, 3D render"
tokens = tokenizer(prompt, return_tensors="pt", padding="max_length",
                   max_length=77, truncation=True).input_ids.to(device)
uncond_tokens = tokenizer("", return_tensors="pt", padding="max_length",
                          max_length=77).input_ids.to(device)
with torch.no_grad():
    text_embeddings = text_encoder(tokens)[0]
    uncond_embeddings = text_encoder(uncond_tokens)[0]
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

def random_camera(radius=2.0):
    """Generate a random camera looking at the origin."""
    theta = torch.rand(1).item() * 2 * np.pi
    phi = torch.rand(1).item() * np.pi * 0.6 + 0.2  # avoid the poles
    x = radius * np.sin(phi) * np.cos(theta)
    y = radius * np.sin(phi) * np.sin(theta)
    z = radius * np.cos(phi)
    cam_pos = torch.tensor([x, y, z], dtype=torch.float32)
    forward = -cam_pos / cam_pos.norm()
    world_up = torch.tensor([0.0, 0.0, 1.0])
    right = torch.linalg.cross(forward, world_up)
    right = right / right.norm()
    up = torch.linalg.cross(right, forward)
    # Build camera-to-world in the OpenCV convention gsplat expects
    # (x right, y down, z forward), then invert to get world-to-camera
    c2w = torch.eye(4, dtype=torch.float32)
    c2w[:3, 0] = right
    c2w[:3, 1] = -up
    c2w[:3, 2] = forward
    c2w[:3, 3] = cam_pos
    viewmat = torch.inverse(c2w)
    return viewmat.to(device)

def render_gaussians(params, viewmat, img_size=256):
    """Render the Gaussian scene from the given viewpoint."""
    scales = params["log_scales"].exp()
    opacities = torch.sigmoid(params["opacities_logit"]).squeeze(-1)
    colors = torch.sigmoid(params["colors_logit"])
    fx = fy = img_size * 1.2
    cx, cy = img_size / 2.0, img_size / 2.0
    Ks = torch.tensor([[fx, 0, cx],
                       [0, fy, cy],
                       [0, 0, 1]], dtype=torch.float32, device=device)
    renders, alphas, meta = rasterization(
        means=params["positions"],
        quats=params["quats"] / params["quats"].norm(dim=-1, keepdim=True),
        scales=scales,
        opacities=opacities,
        colors=colors,
        viewmats=viewmat.unsqueeze(0),
        Ks=Ks.unsqueeze(0),
        width=img_size,
        height=img_size,
        packed=False,
    )
    # renders shape: (1, H, W, 3)
    return renders[0]  # (H, W, 3)
```
The rasterization function from gsplat is fully differentiable. Gradients flow from the rendered pixels all the way back to Gaussian positions, scales, rotations, and colors.
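You can convince yourself that this kind of gradient flow works with a toy 1-D stand-in for the rasterizer. This is pure PyTorch and purely illustrative, not part of the pipeline: two points are splatted onto an 8-pixel "image" with Gaussian footprints, and a pixel-space loss is backpropagated to the point positions.

```python
import torch

# Two splat centers on a 1-D image domain [0, 1]
positions = torch.tensor([0.2, 0.7], requires_grad=True)
pixels = torch.linspace(0, 1, 8)

# Gaussian footprint of each point at each pixel, summed into an image
footprints = torch.exp(-((pixels[None, :] - positions[:, None]) ** 2) / 0.01)
image = footprints.sum(dim=0)          # rendered intensities, shape (8,)

# Pretend the target image is all ones and backprop through the "renderer"
loss = ((image - 1.0) ** 2).mean()
loss.backward()
print(positions.grad.shape)            # torch.Size([2])
```

The same chain rule, scaled up to millions of pixels and thousands of Gaussians with CUDA kernels, is what gsplat implements.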
## The SDS Optimization Loop
Now wire up the training loop. For each iteration: render from a random angle, encode into latent space, add noise at a random timestep, predict the noise with the UNet, and compute the SDS gradient.
```python
optimizer = torch.optim.Adam(params.values(), lr=0.005)
scheduler.set_timesteps(1000)
guidance_scale = 100.0
num_steps = 3000
img_size = 256
alphas_cumprod = scheduler.alphas_cumprod.to(device)

for step in range(num_steps):
    optimizer.zero_grad()
    viewmat = random_camera(radius=2.0)
    rendered = render_gaussians(params, viewmat, img_size=img_size)

    # Prepare image for the VAE: (H,W,3) -> (1,3,H,W), scale to [-1,1]
    img = rendered.permute(2, 0, 1).unsqueeze(0).clamp(0, 1)
    img_512 = F.interpolate(img, size=(512, 512), mode="bilinear",
                            align_corners=False)
    img_sd = img_512 * 2.0 - 1.0

    # Encode to latent space. This must stay inside the autograd graph:
    # the SDS gradient is applied to the latents and has to flow back
    # through the VAE encoder to the rendered pixels.
    latents = vae.encode(img_sd.half()).latent_dist.sample() * 0.18215

    # Random timestep
    t = torch.randint(20, 980, (1,), device=device).long()
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, t)

    # Predict noise with classifier-free guidance
    latent_input = torch.cat([noisy_latents] * 2)
    t_input = torch.cat([t] * 2)
    with torch.no_grad():
        noise_pred = unet(latent_input, t_input,
                          encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # SDS gradient: weighted difference between predicted and actual noise.
    # We backprop it through the VAE encoder and the renderer.
    w = (1 - alphas_cumprod[t]).to(latents.dtype)
    grad = w * (noise_pred - noise)
    latents.backward(gradient=grad)
    optimizer.step()

    if step % 500 == 0:
        print(f"Step {step}/{num_steps}")
```
A few things to note: the guidance_scale for SDS is much higher than normal image generation (100 vs 7.5). This is because SDS needs a strong signal to push the 3D representation. The timestep range of 20-980 avoids the noisy extremes where the gradient signal is either too weak or too chaotic.
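One more detail worth knowing: `latents.backward(gradient=grad)` is equivalent to minimizing an MSE loss against a detached surrogate target, a reformulation you'll see in several open-source SDS implementations. A minimal sketch (the function name `sds_loss` is mine):

```python
import torch
import torch.nn.functional as F

def sds_loss(latents, grad):
    # d(loss)/d(latents) == grad, so autograd applies exactly the SDS gradient
    target = (latents - grad).detach()
    return 0.5 * F.mse_loss(latents, target, reduction="sum")

# Verify the two forms agree on random tensors
latents = torch.randn(1, 4, 8, 8, requires_grad=True)
grad = torch.randn(1, 4, 8, 8)
sds_loss(latents, grad).backward()
print(torch.allclose(latents.grad, grad, atol=1e-6))  # True
```

The surrogate form is convenient because it gives you a scalar "loss" to log and plays nicely with gradient scalers and `accelerate`-style training wrappers.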
Once optimization finishes, save the Gaussians as a PLY file. Splat viewers like SuperSplat expect the field layout popularized by the original 3DGS implementation: log-scales, logit opacities, and colors stored as zeroth-order spherical-harmonic (DC) coefficients.
```python
from plyfile import PlyData, PlyElement

def export_gaussians_to_ply(params, path="scene.ply"):
    """Export Gaussians in the standard 3DGS PLY layout."""
    pos = params["positions"].detach().cpu().numpy()
    # Scales stay in log space and opacities in logit space -- viewers
    # apply exp/sigmoid themselves
    log_scales = params["log_scales"].detach().cpu().numpy()
    opacities = params["opacities_logit"].detach().cpu().numpy()
    quats_raw = params["quats"].detach()
    quats_norm = (quats_raw / quats_raw.norm(dim=-1, keepdim=True)).cpu().numpy()
    # Viewers reconstruct color as 0.5 + SH_C0 * f_dc, so invert that mapping
    SH_C0 = 0.28209479177387814
    rgb = torch.sigmoid(params["colors_logit"]).detach().cpu().numpy()
    f_dc = (rgb - 0.5) / SH_C0
    dtype = [
        ("x", "f4"), ("y", "f4"), ("z", "f4"),
        ("nx", "f4"), ("ny", "f4"), ("nz", "f4"),
        ("f_dc_0", "f4"), ("f_dc_1", "f4"), ("f_dc_2", "f4"),
        ("opacity", "f4"),
        ("scale_0", "f4"), ("scale_1", "f4"), ("scale_2", "f4"),
        ("rot_0", "f4"), ("rot_1", "f4"), ("rot_2", "f4"), ("rot_3", "f4"),
    ]
    elements = np.empty(len(pos), dtype=dtype)
    for i, axis in enumerate(("x", "y", "z")):
        elements[axis] = pos[:, i]
    for axis in ("nx", "ny", "nz"):
        elements[axis] = 0.0  # normals are unused but expected by some viewers
    for i in range(3):
        elements[f"f_dc_{i}"] = f_dc[:, i]
        elements[f"scale_{i}"] = log_scales[:, i]
    elements["opacity"] = opacities.squeeze(-1)
    for i in range(4):
        elements[f"rot_{i}"] = quats_norm[:, i]
    el = PlyElement.describe(elements, "vertex")
    PlyData([el]).write(path)
    print(f"Exported {len(pos)} Gaussians to {path}")

export_gaussians_to_ply(params, "fantasy_castle.ply")
```
You can view the resulting PLY in SuperSplat by dragging the file into the browser. For programmatic rendering from specific angles, use the same render_gaussians function from the training loop with a fixed camera matrix and save the output with torchvision.utils.save_image.
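For turntable renders, a deterministic variant of `random_camera` is handy: feed its output to `render_gaussians` in a loop over azimuths. A sketch mirroring the look-at construction above (the function name and the degree-based angle convention are my own):

```python
import math
import torch

def fixed_camera(azimuth_deg, elevation_deg=20.0, radius=2.0):
    """Deterministic look-at camera for turntable rendering."""
    theta = math.radians(azimuth_deg)
    phi = math.radians(90.0 - elevation_deg)  # polar angle from +z
    cam_pos = torch.tensor([radius * math.sin(phi) * math.cos(theta),
                            radius * math.sin(phi) * math.sin(theta),
                            radius * math.cos(phi)])
    forward = -cam_pos / cam_pos.norm()
    world_up = torch.tensor([0.0, 0.0, 1.0])
    right = torch.linalg.cross(forward, world_up)
    right = right / right.norm()
    up = torch.linalg.cross(right, forward)
    # Camera-to-world in OpenCV convention (x right, y down, z forward)
    c2w = torch.eye(4)
    c2w[:3, 0] = right
    c2w[:3, 1] = -up
    c2w[:3, 2] = forward
    c2w[:3, 3] = cam_pos
    return torch.inverse(c2w)  # world-to-camera matrix

viewmat = fixed_camera(45.0)
R = viewmat[:3, :3]
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-5))  # True: valid rotation
```

Stepping `azimuth_deg` from 0 to 360 in, say, 10-degree increments gives a smooth orbit for a preview video.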
## Common Errors and Fixes
gsplat fails to install or import:
This happens when your CUDA toolkit version doesn’t match PyTorch’s. Check both:
```shell
nvcc --version
python -c "import torch; print(torch.version.cuda)"
```
If they mismatch, reinstall PyTorch for your local CUDA version from pytorch.org. gsplat compiles custom CUDA kernels at install time, so the versions must align.
Out of memory during SDS optimization:
The VAE encoder and UNet eat VRAM on top of your Gaussians. Reduce num_gaussians to 5000 or drop img_size to 128 (the render is upscaled to 512 for Stable Diffusion anyway). Mixed precision with torch.autocast around the rendering step can help too; the UNet and VAE already run in half precision, so there is little further to gain from autocasting those.
Scene collapses to a flat blob:
This usually means your guidance scale is too low or your learning rate is too high. Try guidance_scale=150 and lr=0.001. Also make sure you’re sampling camera positions from a full hemisphere – if all cameras point from the same direction, the optimizer only has incentive to make one side look good.
Colors look washed out:
The SDS loss can push colors toward the mean. Adding a small regularization term that penalizes low color saturation helps. Alternatively, anneal the timestep range during training – start with large timesteps (coarse structure) and shift to smaller timesteps (fine detail) over the course of optimization.