Spin Up Airflow with Docker Compose#
Skip the manual installation. The official Docker Compose setup gets you a working Airflow instance in under two minutes with a scheduler, webserver, and Postgres backend:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
mkdir airflow-ml && cd airflow-ml
# dags/logs/plugins/config are the folders the official 2.9.0 compose file
# mounts; data/training is where the example DAG exchanges and stores files.
mkdir -p dags logs plugins config data/training
# Grab the official docker-compose file
curl -LfO "https://airflow.apache.org/docs/apache-airflow/2.9.0/docker-compose.yaml"
# Set the Airflow user to your host UID so file permissions stay sane
echo -e "AIRFLOW_UID=$(id -u)" > .env
# Initialize the database and create the admin user
docker compose up airflow-init
# Start everything
docker compose up -d
|
The webserver runs at http://localhost:8080. Default credentials are airflow / airflow. Change these immediately if you expose this to anything beyond localhost.
Define Your First ML Data DAG#
A DAG (Directed Acyclic Graph) is how Airflow represents a pipeline. Each node is a task, and edges define dependencies. For an ML ETL pipeline, you typically need four stages: extract raw data, transform and clean it, validate the output, then load it where your training code can reach it.
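Drop this file into dags/ml_etl_pipeline.py. The tasks pass files to each other through /opt/airflow/data, which the official docker-compose.yaml does not mount. First add `- ${AIRFLOW_PROJ_DIR:-.}/data:/opt/airflow/data` to the `volumes:` list under `x-airflow-common`, then restart the stack with `docker compose up -d`: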
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
| from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.operators.bash import BashOperator
# Defaults inherited by every task in this DAG.
default_args = {
    "owner": "ml-team",
    "retries": 2,
    "retry_delay": timedelta(minutes=5),
    "email_on_failure": True,
    # Placeholder: replace with your team's alert address. Failure emails are
    # only delivered once SMTP is configured ([smtp] section / AIRFLOW__SMTP__*).
    "email": ["ml-alerts@example.com"],
}

dag = DAG(
    dag_id="ml_data_etl",
    default_args=default_args,
    description="Extract, transform, validate, and load ML training data",
    # Daily at 06:00 in Airflow's default timezone (UTC unless reconfigured).
    schedule="0 6 * * *",
    # A naive datetime is interpreted in Airflow's default timezone.
    start_date=datetime(2026, 2, 1),
    # Do not backfill runs between start_date and the deploy date.
    catchup=False,
    tags=["ml", "etl", "data"],
)
def extract(**context):
    """Pull this run's slice of raw feedback rows from the source database.

    Selects ``user_feedback`` rows whose ``created_at`` falls inside the DAG
    run's data interval and writes them to a parquet file. The row count and
    file path are pushed to XCom for downstream tasks.

    Args:
        **context: Airflow task context. Reads ``ti``, ``data_interval_start``
            and ``data_interval_end``.
    """
    import os

    import pandas as pd
    from sqlalchemy import create_engine, text

    # Keep real credentials out of the DAG file: set ML_SOURCE_DB_URI in the
    # container environment. The fallback is only the tutorial placeholder.
    db_uri = os.environ.get(
        "ML_SOURCE_DB_URI", "postgresql://user:pass@db-host:5432/production"
    )

    # Bound the query by the run's data interval instead of NOW(), so retries
    # and reruns of the same DAG run extract exactly the same rows.
    query = text(
        """
        SELECT id, text, label, created_at
        FROM user_feedback
        WHERE created_at >= :start AND created_at < :end
        """
    )
    engine = create_engine(db_uri)
    try:
        df = pd.read_sql(
            query,
            engine,
            params={
                "start": context["data_interval_start"],
                "end": context["data_interval_end"],
            },
        )
    finally:
        # Release pooled connections held by this task process.
        engine.dispose()

    output_path = "/opt/airflow/data/raw_extract.parquet"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    df.to_parquet(output_path, index=False)

    # Push metadata to XCom for downstream tasks
    context["ti"].xcom_push(key="row_count", value=len(df))
    context["ti"].xcom_push(key="extract_path", value=output_path)
    print(f"Extracted {len(df)} rows")
def transform(**context):
    """Clean and feature-engineer the extracted data.

    Reads the parquet file whose path ``extract_task`` pushed to XCom. It
    normalizes ``text``, drops rows with missing text/label or text of 10
    characters or fewer, then removes duplicate texts. Finally it adds
    ``text_length`` and ``word_count`` columns.

    Writes the result to parquet and pushes its path (``transform_path``) and
    row count (``clean_row_count``) to XCom.
    """
    import pandas as pd

    ti = context["ti"]
    input_path = ti.xcom_pull(task_ids="extract_task", key="extract_path")
    df = pd.read_parquet(input_path)

    # Normalize first so texts differing only in case or surrounding
    # whitespace are recognized as duplicates below.
    df["text"] = df["text"].str.strip().str.lower()

    # Filter unusable rows *before* deduplicating. Otherwise an unusable copy
    # can be kept as "first" while its usable duplicate is discarded.
    df = df.dropna(subset=["text", "label"])
    df = df[df["text"].str.len() > 10]

    before = len(df)
    df = df.drop_duplicates(subset=["text"], keep="first")
    print(f"Removed {before - len(df)} duplicate rows")

    # assign() builds a new frame, so no column is written into a filtered
    # slice (avoids pandas' SettingWithCopyWarning).
    df = df.assign(
        text_length=df["text"].str.len(),
        word_count=df["text"].str.split().str.len(),
    )

    output_path = "/opt/airflow/data/transformed.parquet"
    df.to_parquet(output_path, index=False)
    ti.xcom_push(key="transform_path", value=output_path)
    ti.xcom_push(key="clean_row_count", value=len(df))
def validate(**context):
    """Run data quality checks on the transformed data before loading.

    Checks the following:
    - the row count is at least 100;
    - ``text`` and ``label`` contain no nulls;
    - the transform kept at least half of the extracted rows;
    - there are at least 2 distinct labels;
    - no label makes up less than 5% of rows.

    Raises:
        ValueError: if any check fails. The message lists every failed check.
    """
    import pandas as pd

    min_rows = 100  # don't train on tiny datasets
    max_drop_ratio = 0.5  # transform losing more than this suggests bad source data
    min_label_classes = 2  # a classifier needs at least two classes
    min_label_share = 0.05  # rarest label's minimum share of rows

    ti = context["ti"]
    input_path = ti.xcom_pull(task_ids="transform_task", key="transform_path")
    raw_count = ti.xcom_pull(task_ids="extract_task", key="row_count")
    df = pd.read_parquet(input_path)

    errors = []

    if len(df) < min_rows:
        errors.append(f"Only {len(df)} rows after cleaning. Minimum is {min_rows}.")

    # Check for null values in critical columns
    null_counts = df[["text", "label"]].isnull().sum()
    if null_counts.sum() > 0:
        errors.append(f"Null values found: {null_counts.to_dict()}")

    # Data loss ratio. A missing (None) or zero raw count counts as total
    # loss instead of raising TypeError/ZeroDivisionError.
    drop_ratio = 1 - len(df) / raw_count if raw_count else 1
    if drop_ratio > max_drop_ratio:
        errors.append(f"Transform dropped {drop_ratio:.0%} of rows. Investigate source data.")

    # Label coverage: same "at least two distinct labels" rule as the GE suite.
    label_classes = df["label"].nunique()
    if label_classes < min_label_classes:
        errors.append(
            f"Only {label_classes} distinct label(s). Minimum is {min_label_classes}."
        )

    # Check label distribution for severe imbalance
    label_dist = df["label"].value_counts(normalize=True)
    if not label_dist.empty and label_dist.min() < min_label_share:
        errors.append(f"Severe label imbalance: {label_dist.to_dict()}")

    if errors:
        raise ValueError("Validation failed:\n" + "\n".join(errors))
    print(f"Validation passed: {len(df)} rows, {label_classes} labels")
def load(**context):
    """Copy the validated dataset into the training data store.

    The file lands at ``/opt/airflow/data/training/dataset_<YYYYmmdd_HHMMSS>.parquet``.
    The timestamp is the run's logical date, so every attempt of one DAG run
    targets the same file. The destination path is pushed to XCom as
    ``training_path``.

    Args:
        **context: Airflow task context. Reads ``ti`` and ``logical_date``.
    """
    import os
    import shutil

    ti = context["ti"]
    source = ti.xcom_pull(task_ids="transform_task", key="transform_path")

    # Use the run's logical date rather than the wall clock: a retry then
    # overwrites its own file instead of publishing a second copy.
    timestamp = context["logical_date"].strftime("%Y%m%d_%H%M%S")
    dest_dir = "/opt/airflow/data/training"
    dest = f"{dest_dir}/dataset_{timestamp}.parquet"
    os.makedirs(dest_dir, exist_ok=True)

    # Copy to a temporary name, then rename into place. os.replace is atomic
    # within one filesystem, so readers never see a partially written file.
    tmp_dest = dest + ".tmp"
    shutil.copy2(source, tmp_dest)
    os.replace(tmp_dest, dest)

    ti.xcom_push(key="training_path", value=dest)
    print(f"Loaded dataset to {dest}")
# One PythonOperator per ETL stage. Each task_id matches the ID the callables
# use in their xcom_pull(task_ids=...) lookups.
extract_task = PythonOperator(task_id="extract_task", python_callable=extract, dag=dag)
transform_task = PythonOperator(task_id="transform_task", python_callable=transform, dag=dag)
validate_task = PythonOperator(task_id="validate_task", python_callable=validate, dag=dag)
load_task = PythonOperator(task_id="load_task", python_callable=load, dag=dag)

# Hands off to an external training script once the dataset is loaded.
retrain_task = BashOperator(
    task_id="retrain_trigger",
    dag=dag,
    bash_command="python /opt/airflow/scripts/trigger_retrain.py",
)

# Linear pipeline: each edge makes the right-hand task wait for the left one.
extract_task >> transform_task
transform_task >> validate_task
validate_task >> load_task
load_task >> retrain_task
|
The >> operator sets task dependencies. Extract runs first, then transform, then validate, then load, then retrain. If validation fails, the downstream tasks never execute – exactly what you want.
Scheduling Strategies That Actually Work#
The schedule parameter accepts cron expressions. Pick your schedule based on how your source data arrives, not how often you wish you had fresh data.
- "0 6 * * *" – Daily at 6 AM. Good default for most production ML systems.
- "0 */4 * * *" – Every 4 hours. Use when your model serves real-time predictions and data drift matters.
- "0 2 * * 0" – Weekly on Sunday at 2 AM. Fine for models that do not degrade quickly.
Set catchup=False unless you specifically need Airflow to backfill every missed run. With catchup=True, deploying a DAG with a start_date months ago creates a storm of backfill runs that will choke your database connections and flood your logs.
Add Data Validation with Great Expectations#
The inline validation above catches the basics. For production, wire in Great Expectations for schema enforcement and statistical checks. Install it in your Airflow Docker image:
1
2
FROM apache/airflow:2.9.0-python3.11
# Re-listing apache-airflow at the image's own version (AIRFLOW_VERSION is set
# in the official image) stops pip from moving Airflow or its pinned
# dependencies, e.g. SQLAlchemy < 2.0, while resolving Great Expectations.
RUN pip install --no-cache-dir "apache-airflow==${AIRFLOW_VERSION}" great-expectations==0.18.0 pandas sqlalchemy
|
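Point docker-compose.yaml at this image by uncommenting `build: .` under `x-airflow-common` (or set AIRFLOW_IMAGE_NAME). Rebuild with `docker compose build`, then replace the validate function with a GE-powered check and pass it as validate_task's python_callable: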
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def validate_with_ge(**context):
    """Validate the transformed dataset with Great Expectations.

    Loads the parquet file whose path ``transform_task`` pushed to XCom and
    registers the frame as a pandas dataframe asset. It then runs schema and
    statistical expectations against it.

    Raises:
        ValueError: listing every expectation result that did not succeed.
    """
    import great_expectations as gx
    import pandas as pd

    parquet_path = context["ti"].xcom_pull(task_ids="transform_task", key="transform_path")
    frame = pd.read_parquet(parquet_path)

    gx_context = gx.get_context()
    batch_request = (
        gx_context.sources.add_pandas("ml_data")
        .add_dataframe_asset("training_batch")
        .build_batch_request(dataframe=frame)
    )
    validator = gx_context.get_validator(batch_request=batch_request)

    # Schema: required columns are populated and the length feature is int64.
    for column in ("text", "label"):
        validator.expect_column_values_to_not_be_null(column)
    validator.expect_column_values_to_be_of_type("text_length", "int64")

    # Statistics: plausible word counts and between 2 and 100 distinct labels.
    validator.expect_column_values_to_be_between("word_count", min_value=2, max_value=5000)
    validator.expect_column_unique_value_count_to_be_between("label", min_value=2, max_value=100)

    outcome = validator.validate()
    if outcome["success"]:
        return
    failures = [item for item in outcome["results"] if not item["success"]]
    raise ValueError(f"GE validation failed: {failures}")
|
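GE's edge over hand-rolled checks is that it can store validation results across runs, so you can spot slow data quality degradation. Note that gx.get_context() without a project directory returns an ephemeral, in-memory context that keeps nothing between runs. To build that history, use a file-backed Data Context (a great_expectations.yml project on a mounted volume) with a validation results store.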
Triggering Model Retraining#
The retrain_trigger task in the DAG calls a script that kicks off training. Use BashOperator or PythonOperator depending on your training setup. For teams running training on Kubernetes or a cloud ML service, an HTTP trigger is cleaner:
1
2
3
4
5
6
7
8
9
10
11
| from airflow.providers.http.operators.http import SimpleHttpOperator
# POST the freshly loaded dataset's path to the training service. `data` is a
# templated field. The `tojson` filter quotes and escapes the XCom value, so
# the body stays valid JSON whatever characters the path contains (and a
# missing value becomes null instead of the bare word None).
retrain_task = SimpleHttpOperator(
    task_id="retrain_trigger",
    http_conn_id="ml_training_api",
    endpoint="/api/v1/train",
    method="POST",
    data='{"dataset": {{ ti.xcom_pull(task_ids="load_task", key="training_path") | tojson }} }',
    headers={"Content-Type": "application/json"},
    dag=dag,
)
|
Set up the ml_training_api connection in Airflow’s Admin > Connections UI, pointing to your training service endpoint. This keeps credentials out of your DAG code.
My strong recommendation: do not trigger retraining on every single ETL run. Add a conditional check – only retrain when you have accumulated enough new data or when validation metrics show drift. A BranchPythonOperator handles this well:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| from airflow.operators.python import BranchPythonOperator
from airflow.operators.empty import EmptyOperator

# Minimum number of cleaned rows a run must produce before retraining is worth it.
MIN_NEW_ROWS_FOR_RETRAIN = 1000


def should_retrain(**context):
    """Choose which branch runs after the dataset is loaded.

    Returns:
        ``"retrain_trigger"`` when transform_task kept at least
        MIN_NEW_ROWS_FOR_RETRAIN rows, otherwise ``"skip_retrain"``.
    """
    ti = context["ti"]
    new_rows = ti.xcom_pull(task_ids="transform_task", key="clean_row_count")
    # A missing XCom (None) means "no new data" rather than a TypeError.
    if (new_rows or 0) >= MIN_NEW_ROWS_FOR_RETRAIN:
        return "retrain_trigger"
    return "skip_retrain"


branch_task = BranchPythonOperator(
    task_id="check_retrain",
    python_callable=should_retrain,
    dag=dag,
)

# Every task_id the branch can return must exist directly downstream of it.
skip_retrain = EmptyOperator(task_id="skip_retrain", dag=dag)

# Use this in place of the direct `load_task >> retrain_task` edge.
load_task >> branch_task >> [retrain_task, skip_retrain]
|
Common Errors#
airflow.exceptions.DagNotFound after adding a new DAG file.
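By default the scheduler looks for new files in dags/ every 5 minutes (dag_dir_list_interval = 300) and re-parses known files every 30 seconds (min_file_process_interval). If you see this right after dropping in a file, wait a few minutes. If it persists, check for import errors. A DAG file that fails to import never appears in the DAG list; its error shows only in a "Broken DAG" banner in the UI. Test with: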
1
# Import the file directly (prints any traceback), then ask Airflow for every import error it recorded.
docker compose exec airflow-webserver python /opt/airflow/dags/ml_etl_pipeline.py
docker compose exec airflow-webserver airflow dags list-import-errors
|
sqlalchemy.exc.OperationalError: could not connect to server inside extract.
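Your Airflow containers run on a Docker network, so localhost inside a container does not point to your host machine. Use the actual hostname of your database. For a database running on the host, use host.docker.internal on Docker Desktop. On Linux, either add `extra_hosts: ["host.docker.internal:host-gateway"]` to the Airflow services, or use the host's bridge gateway IP (often 172.17.0.1). In both cases, make sure the database listens on that interface, not just 127.0.0.1.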
TypeError: Object of type int64 is not JSON serializable when using XCom.
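XCom serializes values to JSON by default, and NumPy scalars such as numpy.int64 and numpy.float64 can break this. Pandas reductions like Series.sum() and .mean() return these types. Cast them to plain Python numbers before pushing; len(df) already returns an int and needs no cast: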
1
# Series.sum() returns numpy.int64 here; int() turns it into a JSON-safe Python int.
context["ti"].xcom_push(key="total_chars", value=int(df["text_length"].sum()))
|
Task received SIGTERM and tasks getting killed mid-run.
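The Compose stack only has the memory Docker gives it; the official guide asks for at least 4 GB, ideally 8 GB. Large transforms running side by side can get OOM-killed. Give Docker more memory, and cut how many tasks each Celery worker runs at once in docker-compose.yaml: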
1
2
3
# Add to the existing `environment:` mapping under x-airflow-common.
# worker_concurrency lives in the [celery] section; there is no
# AIRFLOW__CORE__WORKER_CONCURRENCY option.
environment:
  AIRFLOW__CELERY__WORKER_CONCURRENCY: 2
|
Or switch to LocalExecutor for single-machine setups – it avoids Celery overhead entirely.
Tasks stuck in queued state and never running.
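Usually this means the scheduler is running but no workers are up. Check docker compose ps and confirm the airflow-worker container is healthy. On LocalExecutor, look for an exhausted concurrency limit instead: core parallelism, the DAG's max_active_tasks, or a pool with no free slots (Admin > Pools). Note that parallelism = 0 means unlimited, not zero.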