Quick Start
ARC (Autonomous Recovery Controller) makes neural network training self-healing. Install it, wrap your training loop, and let it recover automatically from NaN explosions and gradient collapse instead of wasting compute on failed runs.
ARC can be added to any PyTorch training loop with just 3 additional lines (plus the import) and requires no architectural changes.
Installation
From PyPI
pip install arc-training
From Source
git clone https://github.com/a-kaushik2209/ARC.git
cd ARC
pip install -e .
Optional Dependencies
# Quote the extras so shells such as zsh do not treat [ ] as a glob pattern
pip install "arc-training[lightning]"
pip install "arc-training[full]"
pip install "arc-training[dev]"
Requirements: Python 3.8+, PyTorch 1.9+, NumPy 1.19+
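To confirm that an environment meets these minimums before installing, a small standard-library check can help. This is a sketch, not part of ARC: `check_requirements`, `parse_version`, and `MINIMUMS` are hypothetical names, and the simple parser drops pre-release tags instead of ordering them. The script itself needs Python 3.8+, since that is when `importlib.metadata` entered the standard library.

```python
from importlib.metadata import PackageNotFoundError, version  # stdlib on Python 3.8+

# Minimum versions from the Requirements line (hypothetical helper table)
MINIMUMS = {"torch": (1, 9), "numpy": (1, 19)}


def parse_version(text):
    """Turn e.g. '2.1.0+cu118' into (2, 1, 0), keeping only the leading digits of each part."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_requirements():
    """Return a list of problems; an empty list means every minimum is met."""
    problems = []
    for package, minimum in MINIMUMS.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if parse_version(installed) < minimum:
            wanted = ".".join(str(part) for part in minimum)
            problems.append(f"{package} {installed} is older than {wanted}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements() or ["All requirements met"]:
        print(problem)
```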
First Integration
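import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from arc import ArcV2
# Synthetic stand-in data; replace with your own DataLoader
dataset = TensorDataset(torch.randn(1024, 784), torch.randint(0, 10, (1024,)))
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
arc = ArcV2.auto(model, optimizer)  # ARC line 1: attach to the model and optimizer
for epoch in range(100):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        output = model(batch_x)
        loss = nn.functional.cross_entropy(output, batch_y)
        loss.backward()
        optimizer.step()
        arc.step(loss)  # ARC line 2: collect signals and check for failures
    arc.end_epoch(epoch)  # ARC line 3: mark the end of the epoch
On each arc.step(loss) call, ARC collects gradient norms, weight statistics, loss trajectory, and optimizer state. When it detects anomalies, it automatically rolls back to the last healthy checkpoint and adjusts hyperparameters.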
How ARC Works
ARC operates as a three-stage pipeline that runs alongside your training loop:
- Signal Collection - Gradient norms, activation statistics, weight health, optimizer state, loss curvature, and spectral features are collected every epoch.
- Feature Extraction - Raw signals are transformed into 12 engineered features: trends, variances, accelerations, correlations, and entropy measures.
- Prediction and Intervention - An MLP classifier evaluates failure risk. If risk is critical, ARC rolls back to the last checkpoint and applies corrective measures.
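The feature-extraction stage can be pictured with three of the named statistic types. The sketch below turns a raw signal history (for example, per-epoch gradient norms) into a trend, a variance, and an acceleration. It is an illustration under assumptions: `extract_features`, `slope`, and the windowing are hypothetical, and ARC's actual 12 features and their formulas are not reproduced here.

```python
from statistics import mean, pvariance


def slope(values):
    """Least-squares slope of values against their index (0, 1, 2, ...)."""
    n = len(values)
    xs = range(n)
    x_bar, y_bar = (n - 1) / 2, mean(values)
    num = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, values))
    den = sum((x - x_bar) ** 2 for x in xs)
    return num / den


def extract_features(history, window=5):
    """Summarize the last `window` readings of one raw signal (hypothetical feature set)."""
    recent = history[-window:]
    if len(recent) < 3:
        raise ValueError("need at least 3 readings for an acceleration estimate")
    first_diff = [b - a for a, b in zip(recent, recent[1:])]
    second_diff = [b - a for a, b in zip(first_diff, first_diff[1:])]
    return {
        "trend": slope(recent),             # overall direction of change
        "variance": pvariance(recent),      # volatility within the window
        "acceleration": mean(second_diff),  # is the change speeding up?
    }


grad_norms = [1.0, 1.1, 1.3, 1.8, 3.0]
print(extract_features(grad_norms))
```

Rising gradient norms with a positive acceleration, as in this toy history, are the kind of pattern a downstream classifier could learn to associate with an impending explosion.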
Signal Monitoring
ARC tracks 12+ real-time signals across your model:
| Signal | Collector | What It Detects |
|---|---|---|
| Gradient norms and entropy | GradientCollector | Vanishing/exploding gradients |
| Activation statistics | ActivationCollector | Representation collapse |
| Weight norms and NaN check | WeightCollector | Weight corruption, dead neurons |
| Optimizer state integrity | OptimizerCollector | Silent momentum failures |
| Loss trajectory and curvature | LossCollector | Divergence, NaN loss |
| Hessian proxy | CurvatureCollector | Loss landscape instability |
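As an illustration of what a gradient collector might compute, the sketch below combines per-layer gradient norms into a global L2 norm, measures how evenly the gradient is spread across layers with Shannon entropy, and flags vanishing or exploding gradients using the default thresholds from the Failure Thresholds table (1e-7 and 1e4). The function name and return format are hypothetical. In PyTorch, the per-layer inputs could come from `p.grad.norm()` for each parameter `p` that has a gradient.

```python
import math
from typing import Dict, Sequence, Union


def gradient_signals(layer_norms: Sequence[float], vanish: float = 1e-7,
                     explode: float = 1e4) -> Dict[str, Union[float, str]]:
    """Summarize per-layer gradient L2 norms (hypothetical collector output)."""
    if not all(math.isfinite(n) for n in layer_norms):
        return {"global_norm": math.nan, "entropy": math.nan, "status": "non_finite"}
    # Global L2 norm of all gradients = sqrt of the sum of squared layer norms
    global_norm = math.hypot(*layer_norms)
    # Entropy of each layer's share of the squared norm: near 0 when one layer
    # dominates, log(num_layers) when the gradient is spread evenly
    total = sum(n * n for n in layer_norms)
    shares = [n * n / total for n in layer_norms] if total > 0 else []
    entropy = -sum(p * math.log(p) for p in shares if p > 0)
    if global_norm > explode:
        status = "exploding"
    elif global_norm < vanish:
        status = "vanishing"
    else:
        status = "healthy"
    return {"global_norm": global_norm, "entropy": entropy, "status": status}


print(gradient_signals([3.0, 4.0]))
```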
Failure Prediction
In the project's benchmark of 200 test scenarios, ARC's MLP classifier achieved 97.5% accuracy and 100% precision (zero false positives) across 5 failure modes:
- Divergence - Loss explosion, NaN/Inf values
- Vanishing Gradients - Gradient norms collapse below threshold
- Exploding Gradients - Gradient norms exceed safe bounds
- Representation Collapse - All activations become identical
- Severe Overfitting - Train/val gap exceeds threshold
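Read together, the two reported numbers pin down where the errors fall. Treating each scenario as a binary call (failure flagged or not), which is an assumption about how the metrics were computed, 100% precision forces the false-positive count to zero (given at least one true positive), so every misclassified scenario must be a missed failure:

```latex
\begin{aligned}
\text{precision} &= \frac{TP}{TP + FP} = 1 &&\Rightarrow\; FP = 0 \\
\text{accuracy} &= \frac{TP + TN}{200} = 0.975 &&\Rightarrow\; TP + TN = 195 \\
FP + FN &= 200 - 195 = 5 &&\Rightarrow\; FN = 5
\end{aligned}
```

Under this reading, all 5 errors were missed detections rather than false alarms.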
Recovery and Rollback
When a failure is detected, ARC automatically:
- Restores model weights from the last healthy checkpoint
- Restores optimizer state (including momentum buffers)
- Reduces learning rate by a configurable factor
- Optionally enables gradient clipping
- Logs the failure event for post-hoc analysis
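The recovery steps above can be sketched with plain state dictionaries. This is a hypothetical illustration, not ARC's implementation: `RollbackManager`, `lr_factor`, and the event format are assumptions. The optimizer state mirrors the `{"state": ..., "param_groups": [...]}` layout returned by PyTorch's `optimizer.state_dict()`, so the restored dicts could be passed to `model.load_state_dict()` and `optimizer.load_state_dict()`.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class RollbackManager:
    """Keeps the last healthy snapshot and rolls back to it (hypothetical sketch)."""

    lr_factor: float = 0.5        # how much to shrink the learning rate per recovery
    enable_clipping: bool = True  # whether recovery should switch on gradient clipping
    events: List[Dict[str, Any]] = field(default_factory=list)
    _snapshot: Optional[Dict[str, Any]] = field(default=None, repr=False)

    def save(self, model_state: Dict[str, Any], optim_state: Dict[str, Any], epoch: int) -> None:
        # Deep-copy so later in-place training updates cannot alter the snapshot
        self._snapshot = {
            "model": copy.deepcopy(model_state),
            "optim": copy.deepcopy(optim_state),  # momentum buffers live under "state"
            "epoch": epoch,
        }

    def recover(self, reason: str, epoch: int) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
        if self._snapshot is None:
            raise RuntimeError("no healthy checkpoint to roll back to")
        # Reduce the LR inside the stored snapshot, so repeated failures compound
        for group in self._snapshot["optim"]["param_groups"]:
            group["lr"] *= self.lr_factor
        # Log the event for post-hoc analysis
        self.events.append({
            "reason": reason,
            "failed_epoch": epoch,
            "restored_epoch": self._snapshot["epoch"],
            "new_lr": self._snapshot["optim"]["param_groups"][0]["lr"],
        })
        # Hand back copies so the caller cannot mutate the stored snapshot
        return (copy.deepcopy(self._snapshot["model"]),
                copy.deepcopy(self._snapshot["optim"]),
                self.enable_clipping)
```

Compounding the reduction in the stored snapshot is one design choice: if training fails again from the same checkpoint, it resumes with an even smaller learning rate instead of replaying the step that just failed.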
Configuration
from arc import Config
config = Config()
config.signal.activation_sample_ratio = 0.2
config.signal.compute_curvature_proxy = True
config.prediction.high_risk_threshold = 0.7
config.prediction.mc_dropout_samples = 30
config.thresholds.loss_explosion_factor = 10.0
config.thresholds.vanishing_grad_threshold = 1e-7
config.overhead.max_overhead_percent = 5.0
Configuration Presets
from arc import Config
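config = Config.low_overhead()       # preset favoring lower monitoring cost
config = Config.high_accuracy()      # alternative preset favoring detection accuracy
json_str = config.to_json()          # serialize the current config
config = Config.from_json(json_str)  # restore it from JSON
Failure Thresholds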
| Parameter | Default | Description |
|---|---|---|
| loss_explosion_factor | 10.0 | Loss must increase by this factor to flag a loss explosion |
| vanishing_grad_threshold | 1e-7 | Gradient norm below this triggers a vanishing-gradient alert |
| exploding_grad_threshold | 1e4 | Gradient norm above this triggers an exploding-gradient alert |
| activation_similarity_threshold | 0.95 | Cosine similarity above this signals representation collapse |
| overfit_gap_threshold | 0.5 | Train/val loss gap above this flags overfitting |
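Each of these five thresholds corresponds to one of the failure modes listed under Failure Prediction. A rule-based check built on them might look like the sketch below. It is hypothetical: the `Thresholds` dataclass mirrors the documented parameter names and defaults, but the comparison logic (for example, measuring loss explosion against the best loss seen so far) is an assumption, and how ARC combines these thresholds with its MLP classifier is not covered here.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Thresholds:
    """Names and defaults copied from the Failure Thresholds table."""
    loss_explosion_factor: float = 10.0
    vanishing_grad_threshold: float = 1e-7
    exploding_grad_threshold: float = 1e4
    activation_similarity_threshold: float = 0.95
    overfit_gap_threshold: float = 0.5


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 1.0 if norm_a == norm_b else 0.0  # two all-zero vectors count as identical
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def detect_failures(losses: Sequence[float], grad_norm: float,
                    act_a: Sequence[float], act_b: Sequence[float],
                    train_loss: float, val_loss: float,
                    t: Optional[Thresholds] = None) -> List[str]:
    """Return the failure modes the current readings trip (hypothetical rules).

    act_a and act_b are one layer's activations for two different inputs.
    """
    if t is None:
        t = Thresholds()
    failures = []
    current = losses[-1]
    if not math.isfinite(current):
        failures.append("divergence")  # NaN/Inf loss
    else:
        # Assumption: the explosion factor is measured against the best finite loss so far
        best = min(x for x in losses if math.isfinite(x))
        if current > t.loss_explosion_factor * best:
            failures.append("divergence")
    if grad_norm < t.vanishing_grad_threshold:
        failures.append("vanishing_gradients")
    if not math.isfinite(grad_norm) or grad_norm > t.exploding_grad_threshold:
        failures.append("exploding_gradients")
    if cosine(act_a, act_b) > t.activation_similarity_threshold:
        failures.append("representation_collapse")
    if val_loss - train_loss > t.overfit_gap_threshold:
        failures.append("severe_overfitting")
    return failures
```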
PINN Stabilizer
from arc import PINNStabilizer
stabilizer = PINNStabilizer(model, n_loss_terms=3)
for epoch in range(1000):
pde_loss, bc_loss, ic_loss = compute_losses(model, inputs)
total = stabilizer.get_stabilized_loss([pde_loss, bc_loss, ic_loss])
optimizer.zero_grad()
total.backward()
stabilizer.stabilize_step()
optimizer.step()
stabilizer.update(epoch, total)Continual Learning (EWC)
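Elastic weight consolidation (EWC) discourages changes to weights that mattered for earlier tasks by adding a quadratic penalty to the new task's loss, weighted by a diagonal Fisher-information estimate. The pure-Python sketch below shows the standard form of that penalty. Whether `arc.get_ewc_loss()` in the example below uses exactly this form, and which strength it applies, is not documented here, so `fisher_diagonal`, `ewc_penalty`, and `lam` are illustrative names.

```python
from typing import Dict, List

Params = Dict[str, float]


def fisher_diagonal(per_sample_grads: List[Params]) -> Params:
    """Empirical diagonal Fisher: mean squared gradient of each parameter on old-task data."""
    n = len(per_sample_grads)
    return {name: sum(g[name] ** 2 for g in per_sample_grads) / n
            for name in per_sample_grads[0]}


def ewc_penalty(params: Params, anchor: Params, fisher: Params, lam: float = 1.0) -> float:
    """(lam / 2) * sum_i F_i * (theta_i - theta*_i)^2, with theta* saved after the old task."""
    return 0.5 * lam * sum(fisher[k] * (params[k] - anchor[k]) ** 2 for k in params)


fisher = fisher_diagonal([{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 0.0}])
print(fisher)  # {'w': 5.0, 'b': 0.0}
print(ewc_penalty({"w": 1.5, "b": 5.0}, {"w": 0.5, "b": 0.2}, fisher))  # 2.5
```

Here `b` can drift freely because its Fisher value is zero, while any change to `w` is penalized in proportion to how sensitive the old task was to it.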
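from arc import ArcV2
# model, optimizer, train(), compute_loss(), and both loaders are your own code
arc = ArcV2.auto(model, optimizer)
# Task 1: train, then consolidate before moving on
arc.begin_task("mnist")
train(model, mnist_loader, arc)
arc.consolidate_task(mnist_loader)
# Task 2: add the EWC penalty to the new task's loss
arc.begin_task("fashion_mnist")
for epoch in range(10):
    for batch in fashion_loader:
        optimizer.zero_grad()
        loss = compute_loss(model, batch) + arc.get_ewc_loss()
        loss.backward()
        optimizer.step()
        arc.step(loss)
    arc.end_epoch(epoch)
Uncertainty Quantification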
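A common recipe for conformal classifiers is split conformal prediction: score how poorly the model fits each held-out calibration example, take a finite-sample-corrected quantile of those scores, and return every class whose score falls under it. With `alpha=0.1`, the resulting sets contain the true label at least 90% of the time on average, provided calibration and test data are exchangeable. The pure-Python sketch below implements that recipe with the score 1 minus the probability of the true class. It is not necessarily what `ConformalPredictor` in the example below does internally, and `calibrate_threshold` and `prediction_set` are hypothetical names.

```python
import math
from typing import List, Sequence


def calibrate_threshold(cal_probs: Sequence[Sequence[float]], cal_labels: Sequence[int],
                        alpha: float = 0.1) -> float:
    """Return q_hat, the conformal quantile of nonconformity scores 1 - p(true class)."""
    scores = sorted(1.0 - probs[y] for probs, y in zip(cal_probs, cal_labels))
    n = len(scores)
    # Finite-sample correction: take the ceil((n + 1) * (1 - alpha))-th smallest score;
    # the tiny tolerance guards against float error pushing an exact integer upward
    k = math.ceil((n + 1) * (1 - alpha) - 1e-12)
    if k > n:
        return math.inf  # too few calibration points: every class must be included
    return scores[k - 1]


def prediction_set(probs: Sequence[float], q_hat: float) -> List[int]:
    """All classes whose nonconformity score is within the calibrated threshold."""
    return [c for c, p in enumerate(probs) if 1.0 - p <= q_hat]


true_p = [0.9, 0.8, 0.7, 0.95, 0.6, 0.85, 0.75, 0.65, 0.5]
cal_probs = [[p, 1 - p] for p in true_p]  # toy 2-class calibration outputs, true class 0
q_hat = calibrate_threshold(cal_probs, [0] * len(true_p), alpha=0.1)
print(q_hat, prediction_set([0.6, 0.4], q_hat))
```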
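from arc import ConformalPredictor
# calibration_loader: held-out labeled data that was not used for training
cp = ConformalPredictor(model, alpha=0.1)  # alpha=0.1 conventionally targets 90% coverage
cp.calibrate(calibration_loader)
result = cp.predict(test_input)
print(result.prediction, result.confidence, result.set_members)
Adversarial Defense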
from arc import AdversarialDetector
detector = AdversarialDetector(model)
detector.fit(clean_data_loader)
alert = detector.detect(suspicious_input)
print(alert.is_adversarial, alert.confidence)
PyTorch Lightning
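import pytorch_lightning as pl
from arc import ArcCallback  # install with the [lightning] extra
trainer = pl.Trainer(callbacks=[ArcCallback()])
trainer.fit(lit_model, train_dataloaders=train_loader)  # lit_model: your LightningModule
Docker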
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
RUN pip install arc-training
COPY . /app
WORKDIR /app
CMD ["python", "train.py"]
Changelog
v4.2.2 - March 2026
- 97.5% prediction accuracy (up from 86.5%) with 12-feature MLP
- Every paper claim backed by experiment scripts
- Real overhead measurement: less than 10% for models above 250K params
- Honest ablation table from real experiment data
v4.0.0 - January 2026
- 100% recovery on all numeric failure types
- OOM recovery across all training stages
- Quantized FP16 checkpointing with 50% memory reduction
- Silent failure detection (accuracy collapse, dead neurons)
v3.0.0 - December 2025
- Multi-architecture support (CNN, ViT, Transformer, Diffusion UNet)
- Scaling validated to 387M parameters