The Core Idea: SDS Loss Meets 3D Gaussians

Text-to-3D with Gaussian splatting works by representing a scene as thousands of 3D Gaussians – each one a small colored ellipsoid with a position, covariance, opacity, and color. You render these Gaussians from a random camera angle using differentiable rasterization, then use a pretrained diffusion model (Stable Diffusion) to compute a Score Distillation Sampling (SDS) loss. That loss tells you “this rendered image doesn’t look like your text prompt,” and you backpropagate through the renderer to update the Gaussian parameters. After a few thousand steps, the Gaussians arrange themselves into a coherent 3D scene that matches your prompt from every angle.
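Stripped of the diffusion machinery, the SDS update is just a weighted noise residual. Here is a toy sketch of that idea; the stand-in `predict_noise` is purely illustrative (in practice it is Stable Diffusion's UNet with classifier-free guidance):

```python
import torch

def sds_grad(image, predict_noise, alpha_bar_t):
    """SDS in one line of math:
    x_t = sqrt(abar)*x + sqrt(1-abar)*eps,  grad = w(t) * (eps_hat - eps)."""
    noise = torch.randn_like(image)
    noisy = alpha_bar_t.sqrt() * image + (1 - alpha_bar_t).sqrt() * noise
    eps_hat = predict_noise(noisy)  # the diffusion model's denoising guess
    return (1 - alpha_bar_t) * (eps_hat - noise)

# Stand-in predictor that "prefers" a mid-gray image (illustration only)
predict_noise = lambda x: x - 0.5

g = sds_grad(torch.rand(1, 3, 8, 8), predict_noise, torch.tensor(0.5))
print(g.shape)  # torch.Size([1, 3, 8, 8]) -- one gradient value per pixel
```

The gradient has the same shape as the image, which is exactly what lets you hand it to `backward()` and push it through the renderer.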

Install what you need:

pip install gsplat torch diffusers transformers accelerate plyfile numpy

gsplat is a fast CUDA-based Gaussian splatting renderer from Nerfstudio. It handles the differentiable rasterization so you don’t have to write custom CUDA kernels.

Initializing the 3D Gaussians

Start by creating a random point cloud. Each Gaussian gets a 3D position, a scale, a rotation quaternion, an opacity value, and RGB color. All of these are learnable parameters.

import torch
import torch.nn as nn
import numpy as np

device = torch.device("cuda")

num_gaussians = 10000

# Random positions clustered near the origin (Gaussian, std 0.5)
positions = torch.randn(num_gaussians, 3, device=device) * 0.5
# Log-scale (exponentiated later to keep scales positive)
log_scales = torch.full((num_gaussians, 3), -3.0, device=device)
# Identity quaternions with small noise
quats = torch.zeros(num_gaussians, 4, device=device)
quats[:, 0] = 1.0
quats += torch.randn_like(quats) * 0.01
quats = quats / quats.norm(dim=-1, keepdim=True)
# Opacities stored in logit space (sigmoid(0.5) gives ~0.62 initially)
opacities_logit = torch.full((num_gaussians, 1), 0.5, device=device)
# RGB colors in sigmoid space
colors_logit = torch.randn(num_gaussians, 3, device=device)

# Make everything a leaf tensor requiring grad
params = {
    "positions": nn.Parameter(positions),
    "log_scales": nn.Parameter(log_scales),
    "quats": nn.Parameter(quats),
    "opacities_logit": nn.Parameter(opacities_logit),
    "colors_logit": nn.Parameter(colors_logit),
}

The key design choice here: store scales in log-space and opacities/colors in logit-space. This way the optimizer works in unconstrained space while the actual values stay physically valid after applying exp and sigmoid.
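A quick sanity check of that mapping: whatever unconstrained value the optimizer produces, exp and sigmoid pull it back into a valid range.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# A log-space scale of -3.0 maps to a small but strictly positive scale
print(round(math.exp(-3.0), 4))   # 0.0498
# A logit-space opacity of 0.5 maps into (0, 1)
print(round(sigmoid(0.5), 3))     # 0.622
# Even an extreme optimizer step stays physically valid
print(0.0 < sigmoid(-50.0) < 1.0)  # True
```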

Rendering with gsplat and Computing SDS Loss

This is where the pieces come together. You render from a random camera, pass the rendered image to Stable Diffusion, and compute the SDS gradient.

import torch.nn.functional as F
from gsplat import rasterization
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion for SDS guidance
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=True)

# Extract the components we need for SDS
vae = sd_pipe.vae
unet = sd_pipe.unet
text_encoder = sd_pipe.text_encoder
tokenizer = sd_pipe.tokenizer
scheduler = sd_pipe.scheduler

# Encode the text prompt once
prompt = "a fantasy castle on a floating island, detailed, 3D render"
tokens = tokenizer(prompt, return_tensors="pt", padding="max_length",
                   max_length=77, truncation=True).input_ids.to(device)
with torch.no_grad():
    text_embeddings = text_encoder(tokens)[0]
    uncond_tokens = tokenizer("", return_tensors="pt", padding="max_length",
                              max_length=77).input_ids.to(device)
    uncond_embeddings = text_encoder(uncond_tokens)[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])


def random_camera(radius=2.0):
    """Generate a random camera looking at the origin (OpenCV convention)."""
    theta = torch.rand(1).item() * 2 * np.pi
    phi = torch.rand(1).item() * np.pi * 0.6 + 0.2  # avoid poles
    x = radius * np.sin(phi) * np.cos(theta)
    y = radius * np.sin(phi) * np.sin(theta)
    z = radius * np.cos(phi)
    cam_pos = torch.tensor([x, y, z], dtype=torch.float32)
    forward = -cam_pos / cam_pos.norm()  # camera +z points at the origin
    world_up = torch.tensor([0.0, 0.0, 1.0])
    right = torch.linalg.cross(forward, world_up)
    right = right / right.norm()
    down = torch.linalg.cross(forward, right)  # OpenCV cameras are y-down
    # Build camera-to-world, then invert to get the world-to-camera view matrix
    # (gsplat expects OpenCV-convention world-to-camera matrices)
    c2w = torch.eye(4, dtype=torch.float32)
    c2w[:3, 0] = right
    c2w[:3, 1] = down
    c2w[:3, 2] = forward
    c2w[:3, 3] = cam_pos
    viewmat = torch.inverse(c2w)
    return viewmat.to(device)


def render_gaussians(params, viewmat, img_size=256):
    """Render the Gaussian scene from the given viewpoint."""
    scales = params["log_scales"].exp()
    opacities = torch.sigmoid(params["opacities_logit"]).squeeze(-1)
    colors = torch.sigmoid(params["colors_logit"])

    fx = fy = img_size * 1.2
    cx, cy = img_size / 2.0, img_size / 2.0
    Ks = torch.tensor([[fx, 0, cx],
                       [0, fy, cy],
                       [0, 0, 1]], dtype=torch.float32, device=device)

    renders, alphas, meta = rasterization(
        means=params["positions"],
        quats=params["quats"] / params["quats"].norm(dim=-1, keepdim=True),
        scales=scales,
        opacities=opacities,
        colors=colors,
        viewmats=viewmat.unsqueeze(0),
        Ks=Ks.unsqueeze(0),
        width=img_size,
        height=img_size,
        packed=False,
    )
    # renders shape: (1, H, W, 3)
    return renders[0]  # (H, W, 3)

The rasterization function from gsplat is fully differentiable. Gradients flow from the rendered pixels all the way back to Gaussian positions, scales, rotations, and colors.
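You can convince yourself of this end-to-end differentiability with a toy 2D analogue: splat a single isotropic Gaussian onto a small image using plain tensor ops and backprop a pixel loss to its position. gsplat does the same thing, just in CUDA and with full 3D covariances.

```python
import torch

# One isotropic 2D Gaussian splatted onto an 8x8 image
mean = torch.tensor([4.0, 4.0], requires_grad=True)
sigma = 1.5
ys, xs = torch.meshgrid(torch.arange(8.0), torch.arange(8.0), indexing="ij")
img = torch.exp(-((xs - mean[0]) ** 2 + (ys - mean[1]) ** 2) / (2 * sigma ** 2))

# Pretend the "target" is a fully lit image and backprop the pixel error
loss = (img - 1.0).pow(2).mean()
loss.backward()
print(mean.grad)  # nonzero gradient: pixels differentiate back to position
```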

The SDS Optimization Loop

Now wire up the training loop. For each iteration: render from a random angle, encode into latent space, add noise at a random timestep, predict the noise with the UNet, and compute the SDS gradient.

optimizer = torch.optim.Adam(params.values(), lr=0.005)
scheduler.set_timesteps(1000)
guidance_scale = 100.0
num_steps = 3000
img_size = 256

for step in range(num_steps):
    optimizer.zero_grad()

    viewmat = random_camera(radius=2.0)
    rendered = render_gaussians(params, viewmat, img_size=img_size)

    # Prepare image for VAE: (H,W,3) -> (1,3,H,W), scale to [-1,1]
    img = rendered.permute(2, 0, 1).unsqueeze(0).clamp(0, 1)
    img_512 = F.interpolate(img, size=(512, 512), mode="bilinear",
                            align_corners=False)
    img_sd = img_512 * 2.0 - 1.0

    # Encode to latent space -- keep gradients so SDS can reach the Gaussians
    latents = vae.encode(img_sd.half()).latent_dist.sample() * 0.18215

    # Random timestep
    t = torch.randint(20, 980, (1,), device=device).long()
    noise = torch.randn_like(latents)
    noisy_latents = scheduler.add_noise(latents, noise, t)

    # Predict noise with classifier-free guidance
    latent_input = torch.cat([noisy_latents] * 2)
    t_input = torch.cat([t] * 2)
    with torch.no_grad():
        noise_pred = unet(latent_input, t_input,
                          encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (
        noise_pred_text - noise_pred_uncond
    )

    # SDS gradient: weighted difference between predicted and injected noise
    # (alphas_cumprod lives on the CPU, so move it before indexing with t)
    w = (1 - scheduler.alphas_cumprod.to(device)[t]).to(latents.dtype)
    grad = w * (noise_pred - noise)

    # Backprop the SDS gradient through the VAE encoder and the renderer
    latents.backward(gradient=grad)
    optimizer.step()

    if step % 500 == 0:
        print(f"Step {step}/{num_steps}")

A few things to note: the guidance_scale for SDS is much higher than normal image generation (100 vs 7.5). This is because SDS needs a strong signal to push the 3D representation. The timestep range of 20-980 avoids the noisy extremes where the gradient signal is either too weak or too chaotic.
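To see how the weight w(t) = 1 − ᾱ_t behaves across that range, you can reconstruct Stable Diffusion's "scaled_linear" beta schedule in NumPy (beta_start=0.00085, beta_end=0.012, matching the v1.5 scheduler config) instead of pulling it from diffusers:

```python
import numpy as np

# "scaled_linear" schedule: betas are the squares of a linspace of sqrt-betas
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
w = 1.0 - alphas_cumprod

# The weight grows monotonically: tiny at t=20, near 1 at t=980
for t in (20, 500, 980):
    print(t, round(float(w[t]), 3))
```

Small t barely perturbs the latents (weak signal), while t near 1000 is almost pure noise (chaotic signal), which is why the loop samples the middle of the range.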

Exporting to PLY Format

Once optimization finishes, save the Gaussians as a PLY file. Viewers like SuperSplat expect the standard 3D Gaussian splatting attribute layout: positions, spherical-harmonics DC color coefficients, logit-space opacities, log-space scales, and a rotation quaternion per point.

from plyfile import PlyData, PlyElement

SH_C0 = 0.28209479177387814  # zeroth-order spherical-harmonics basis constant

def export_gaussians_to_ply(params, path="scene.ply"):
    """Export to the standard 3DGS PLY layout (log scales, logit opacities)."""
    pos = params["positions"].detach().cpu().numpy()
    log_scales = params["log_scales"].detach().cpu().numpy()
    quats_raw = params["quats"].detach()
    quats_norm = (quats_raw / quats_raw.norm(dim=-1, keepdim=True)).cpu().numpy()
    opacities_logit = params["opacities_logit"].detach().cpu().numpy()
    # Viewers reconstruct rgb = 0.5 + SH_C0 * f_dc, so invert that mapping here
    rgb = torch.sigmoid(params["colors_logit"]).detach().cpu().numpy()
    f_dc = (rgb - 0.5) / SH_C0

    dtype = [
        ("x", "f4"), ("y", "f4"), ("z", "f4"),
        ("nx", "f4"), ("ny", "f4"), ("nz", "f4"),
        ("f_dc_0", "f4"), ("f_dc_1", "f4"), ("f_dc_2", "f4"),
        ("opacity", "f4"),
        ("scale_0", "f4"), ("scale_1", "f4"), ("scale_2", "f4"),
        ("rot_0", "f4"), ("rot_1", "f4"), ("rot_2", "f4"), ("rot_3", "f4"),
    ]

    elements = np.empty(len(pos), dtype=dtype)
    for i, axis in enumerate(("x", "y", "z")):
        elements[axis] = pos[:, i]
    for normal in ("nx", "ny", "nz"):
        elements[normal] = 0.0  # normals are unused but expected by some viewers
    for i in range(3):
        elements[f"f_dc_{i}"] = f_dc[:, i]
        elements[f"scale_{i}"] = log_scales[:, i]
    elements["opacity"] = opacities_logit.squeeze()
    for i in range(4):
        elements[f"rot_{i}"] = quats_norm[:, i]

    el = PlyElement.describe(elements, "vertex")
    PlyData([el]).write(path)
    print(f"Exported {len(pos)} Gaussians to {path}")

export_gaussians_to_ply(params, "fantasy_castle.ply")

You can view the resulting PLY in SuperSplat by dragging the file into the browser. For programmatic renders from specific angles, reuse the same render_gaussians function from the training loop with a fixed camera matrix, permute the output to (3, H, W), and save it with torchvision.utils.save_image.

Common Errors and Fixes

gsplat fails to install or import: This happens when your CUDA toolkit version doesn’t match PyTorch’s. Check both:

nvcc --version
python -c "import torch; print(torch.version.cuda)"

If they mismatch, reinstall PyTorch for your local CUDA version from pytorch.org. gsplat compiles custom CUDA kernels at install time, so the versions must align.

Out of memory during SDS optimization: The VAE encoder and UNet eat VRAM on top of your Gaussians. Reduce num_gaussians to 5000 or drop img_size to 128 (the render is upscaled to 512 for Stable Diffusion anyway). You can also wrap the rendering step in torch.autocast("cuda", dtype=torch.float16) for mixed precision; the UNet and VAE already run in float16, so there is little more to save on the diffusion side.

Scene collapses to a flat blob: This usually means your guidance scale is too low or your learning rate is too high. Try guidance_scale=150 and lr=0.001. Also make sure you’re sampling camera positions from a full hemisphere – if all cameras point from the same direction, the optimizer only has incentive to make one side look good.
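A cheap way to verify your camera sampling actually surrounds the scene is to histogram a batch of sampled angles. This sanity check uses the same angle ranges as the random_camera function above:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 2 * np.pi, 1000)           # azimuth, full circle
phi = rng.uniform(0.2, 0.2 + 0.6 * np.pi, 1000)     # elevation, poles avoided

# Bucket azimuths into 8 sectors; every sector should receive samples
counts, _ = np.histogram(theta, bins=8, range=(0.0, 2 * np.pi))
print(bool((counts > 0).all()))  # True -> cameras cover all sides of the scene
```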

Colors look washed out: The SDS loss can push colors toward the mean. Adding a small regularization term that penalizes low color saturation helps. Alternatively, anneal the timestep range during training – start with large timesteps (coarse structure) and shift to smaller timesteps (fine detail) over the course of optimization.
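One possible annealing schedule is a linear shrink of the upper timestep bound (the function name and the 400 endpoint here are illustrative choices, not from any particular paper):

```python
def anneal_timestep_bounds(step, num_steps, t_max=980, t_min=20, t_end=400):
    """Shrink the upper timestep bound linearly from t_max to t_end."""
    frac = min(step / num_steps, 1.0)
    hi = int(t_max - frac * (t_max - t_end))
    return t_min, hi

print(anneal_timestep_bounds(0, 3000))     # (20, 980) -- coarse structure first
print(anneal_timestep_bounds(3000, 3000))  # (20, 400) -- fine detail later
```

In the training loop, sample t with torch.randint(lo, hi, (1,)) using these bounds instead of the fixed 20-980 range.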