soroban-abacus-flashcards/apps/web/scripts/train-boundary-detector/train_model.py

#!/usr/bin/env python3
"""
Train a neural network to detect abacus boundary corners using heatmap regression.
Architecture:
- MobileNetV2 backbone (pretrained ImageNet)
- BiFPN-style feature fusion neck
- 4 heatmap outputs (one per corner: TL, TR, BL, BR)
- DSNT layer converts heatmaps to coordinates (differentiable)
Loss:
- Adaptive Wing Loss on heatmaps
- Smooth L1 on coordinates
- Convexity regularization
Heatmap regression with a differentiable soft-argmax (DSNT) typically outperforms direct coordinate regression on keypoint localization tasks.
Usage:
python scripts/train-boundary-detector/train_model.py [options]
Options:
    --data-dir DIR         Training data directory (default: ./data/vision-training/boundary-frames)
    --output-dir DIR       Output directory for model (default: ./public/models/abacus-boundary-detector)
    --epochs N             Number of training epochs (default: 100)
    --batch-size N         Batch size (default: 16)
    --validation-split R   Validation split ratio (default: 0.2)
    --heatmap-size N       Output heatmap size (default: 56, i.e. 224/4)
    --heatmap-sigma S      Gaussian sigma for heatmap generation (default: 2.0)
    --color-augmentation   Enable color-only augmentation (brightness/contrast/saturation)
    --no-marker-masking    Disable ArUco marker masking (not recommended for production)
    --stop-file PATH       File that, when created, signals training to stop and save
    --session-id ID        Session ID for tracking this training run
    --manifest-file PATH   Path to manifest JSON listing specific items to train on
    --json-progress        Output JSON progress for streaming to web UI
"""
import argparse
import base64
import io
import json
import os
import platform
import socket
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Tuple, List
import numpy as np
# Import marker masking (from same directory)
from marker_masking import mask_markers
def get_hardware_info() -> dict:
"""Detect available hardware for TensorFlow training."""
result = {
"device": "unknown",
"deviceName": "Unknown",
"deviceType": "unknown",
"tensorflowVersion": None,
"platform": platform.system(),
"machine": platform.machine(),
"processor": platform.processor(),
}
try:
import tensorflow as tf
result["tensorflowVersion"] = tf.__version__
gpus = tf.config.list_physical_devices('GPU')
is_apple_silicon = (
platform.system() == "Darwin" and
platform.machine() == "arm64"
)
if gpus:
result["deviceType"] = "gpu"
if is_apple_silicon:
result["device"] = "metal"
result["deviceName"] = "Apple Silicon GPU (Metal)"
# Try to get chip info
try:
import subprocess
sp_output = subprocess.check_output(
["system_profiler", "SPHardwareDataType", "-json"],
text=True,
timeout=5
)
sp_data = json.loads(sp_output)
hardware = sp_data.get("SPHardwareDataType", [{}])[0]
chip_type = hardware.get("chip_type", "")
if chip_type:
result["deviceName"] = f"{chip_type} GPU (Metal)"
result["chipType"] = chip_type
memory = hardware.get("physical_memory", "")
if memory:
result["systemMemory"] = memory
except Exception:
pass
else:
result["device"] = "cuda"
result["deviceName"] = gpus[0].name
result["gpuCount"] = len(gpus)
else:
result["device"] = "cpu"
result["deviceType"] = "cpu"
result["deviceName"] = "CPU"
except ImportError:
result["error"] = "TensorFlow not installed"
except Exception as e:
result["error"] = str(e)
return result
def get_environment_info() -> dict:
"""Get information about the training environment."""
return {
"hostname": socket.gethostname(),
"username": os.environ.get("USER", os.environ.get("USERNAME", "unknown")),
"pythonVersion": platform.python_version(),
"workingDirectory": os.getcwd(),
"platform": platform.system(),
"platformVersion": platform.version(),
"architecture": platform.machine(),
}
def emit_progress(event_type: str, data: dict, use_json: bool = False):
"""Emit a progress event, either as text or JSON."""
if use_json:
print(json.dumps({"event": event_type, **data}), flush=True)
else:
if event_type == "status":
print(data.get("message", ""))
elif event_type == "epoch":
print(
f"Epoch {data['epoch']}/{data['total_epochs']} - "
f"loss: {data['loss']:.4f} - val_loss: {data['val_loss']:.4f} - "
f"coord_mae: {data.get('coord_mae', 0):.4f}"
)
elif event_type == "complete":
print(f"\nTraining complete! Final coord MAE: {data['final_mae']:.4f}")
elif event_type == "error":
print(f"Error: {data.get('message', 'Unknown error')}")
def generate_inference_samples(model, X_val, coords_val, num_samples: int = 5, image_size: int = 224):
"""
Generate inference samples for visualization.
Args:
model: Trained model
X_val: Validation images (N, H, W, 3) with values 0-1
coords_val: Ground truth coordinates (N, 4, 2) normalized 0-1
num_samples: Number of samples to generate
image_size: Image size in pixels (for pixel error calculation)
Returns:
List of sample dicts with imageBase64, predicted, groundTruth, pixelError
"""
import tensorflow as tf
from PIL import Image
    # Select random validation indices (sampled fresh via np.random on each call)
num_val = len(X_val)
if num_val < num_samples:
indices = list(range(num_val))
else:
indices = np.random.choice(num_val, size=num_samples, replace=False)
samples = []
for idx in indices:
# Get image and ground truth
image = X_val[idx] # Shape: (224, 224, 3), values 0-1
gt_coords = coords_val[idx] # Shape: (4, 2), normalized 0-1
# Run inference
image_batch = np.expand_dims(image, axis=0) # (1, 224, 224, 3)
heatmaps = model(image_batch, training=False) # (1, H, W, 4)
# Decode heatmaps to coordinates using DSNT
# Inline simplified DSNT decode for single sample
hm = heatmaps[0].numpy() # (H, W, 4)
h, w, num_kp = hm.shape
pred_coords = []
for kp in range(num_kp):
hm_kp = hm[:, :, kp]
# Softmax normalization
hm_kp = hm_kp - hm_kp.max()
hm_kp = np.exp(hm_kp)
hm_kp = hm_kp / (hm_kp.sum() + 1e-8)
# Compute expected x, y coordinates
x_range = np.linspace(0, 1, w)
y_range = np.linspace(0, 1, h)
expected_x = np.sum(hm_kp * x_range.reshape(1, -1))
expected_y = np.sum(hm_kp * y_range.reshape(-1, 1))
pred_coords.append([expected_x, expected_y])
pred_coords = np.array(pred_coords) # (4, 2)
# Calculate pixel error for this sample
pixel_errors = np.sqrt(np.sum((pred_coords - gt_coords) ** 2, axis=1)) * image_size
mean_pixel_error = float(np.mean(pixel_errors))
# Encode image as base64 JPEG
img_uint8 = (image * 255).astype(np.uint8)
pil_img = Image.fromarray(img_uint8)
buffer = io.BytesIO()
pil_img.save(buffer, format='JPEG', quality=80)
img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        # Flatten coordinates to [tl_x, tl_y, tr_x, tr_y, bl_x, bl_y, br_x, br_y],
        # matching the model's corner order (TL, TR, BL, BR); any reordering for
        # display is handled in the UI.
pred_flat = pred_coords.flatten().tolist()
gt_flat = gt_coords.flatten().tolist()
samples.append({
"imageBase64": img_base64,
"predicted": pred_flat,
"groundTruth": gt_flat,
"pixelError": mean_pixel_error
})
return samples
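# Each returned sample is JSON-serializable for the web UI (illustrative values):
#   {"imageBase64": "<base64 JPEG>", "predicted": [<8 floats>],
#    "groundTruth": [<8 floats>], "pixelError": 3.7}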
def parse_args():
parser = argparse.ArgumentParser(description="Train abacus boundary detector (heatmap + DSNT)")
parser.add_argument(
"--data-dir",
type=str,
default="./data/vision-training/boundary-frames",
help="Training data directory",
)
parser.add_argument(
"--output-dir",
type=str,
default="./public/models/abacus-boundary-detector",
help="Output directory for model",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of training epochs"
)
parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
parser.add_argument(
"--validation-split",
type=float,
default=0.2,
help="Validation split ratio",
)
parser.add_argument(
"--heatmap-size",
type=int,
default=56,
help="Size of output heatmaps (default: 56 for 224/4)",
)
parser.add_argument(
"--heatmap-sigma",
type=float,
default=2.0,
help="Gaussian sigma for heatmap generation",
)
parser.add_argument(
"--json-progress",
action="store_true",
help="Output JSON progress for streaming to web UI",
)
parser.add_argument(
"--no-augmentation",
action="store_true",
help="Disable data augmentation (ignored for boundary detector)",
)
parser.add_argument(
"--color-augmentation",
action="store_true",
help="Enable color-only augmentation (brightness, contrast, saturation variations)",
)
parser.add_argument(
"--no-marker-masking",
action="store_true",
help="Disable marker masking (for comparison testing - not recommended for production)",
)
parser.add_argument(
"--stop-file",
type=str,
default=None,
help="Path to a file that, when created, signals training to stop and save",
)
parser.add_argument(
"--session-id",
type=str,
default=None,
help="Session ID for tracking this training run (for session management)",
)
parser.add_argument(
"--manifest-file",
type=str,
default=None,
help="Path to manifest JSON listing specific items to train on (for filtered training).",
)
return parser.parse_args()
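# Example invocation (illustrative values):
#   python scripts/train-boundary-detector/train_model.py \
#       --epochs 50 --batch-size 8 --color-augmentation --json-progress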
# =============================================================================
# Image Preprocessing
# =============================================================================
def resize_sample(image: np.ndarray, keypoints: List[Tuple[float, float]],
image_size: int = 224) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
"""
Resize image to target size. Keypoints are already normalized so they don't change.
Args:
image: Input image (H, W, 3) uint8
keypoints: List of 4 (x, y) normalized coordinates (0-1)
image_size: Target size
Returns:
Resized image and unchanged keypoints
"""
from PIL import Image as PILImage
img = PILImage.fromarray(image.astype(np.uint8))
img = img.resize((image_size, image_size), PILImage.BILINEAR)
# Keypoints are normalized (0-1) so they don't change with resize
return np.array(img), keypoints
# =============================================================================
# Color Augmentation (No geometric transforms - those break corner labels)
# =============================================================================
def augment_brightness(image: np.ndarray, factor: float) -> np.ndarray:
"""
Adjust image brightness.
Args:
image: Input image (H, W, 3) uint8
factor: Brightness factor (1.0 = no change, >1 = brighter, <1 = darker)
Returns:
Brightness-adjusted image (H, W, 3) uint8
"""
return np.clip(image.astype(np.float32) * factor, 0, 255).astype(np.uint8)
def augment_contrast(image: np.ndarray, factor: float) -> np.ndarray:
"""
Adjust image contrast.
Args:
image: Input image (H, W, 3) uint8
factor: Contrast factor (1.0 = no change, >1 = more contrast, <1 = less contrast)
Returns:
Contrast-adjusted image (H, W, 3) uint8
"""
mean = np.mean(image, axis=(0, 1), keepdims=True)
return np.clip((image.astype(np.float32) - mean) * factor + mean, 0, 255).astype(np.uint8)
def augment_saturation(image: np.ndarray, factor: float) -> np.ndarray:
"""
Adjust image saturation.
Args:
image: Input image (H, W, 3) uint8 RGB
factor: Saturation factor (1.0 = no change, >1 = more saturated, <1 = less/grayscale)
Returns:
Saturation-adjusted image (H, W, 3) uint8
"""
# Convert to grayscale for desaturation reference
gray = np.mean(image, axis=2, keepdims=True)
return np.clip(gray + (image.astype(np.float32) - gray) * factor, 0, 255).astype(np.uint8)
def apply_color_augmentation(image: np.ndarray) -> List[np.ndarray]:
"""
Apply color augmentation variations to an image.
Returns multiple augmented versions (including original) with random
brightness, contrast, and saturation adjustments.
Args:
image: Input image (H, W, 3) uint8
Returns:
List of augmented images including the original
"""
augmented = [image] # Always include original
# Random brightness variations
for _ in range(2):
factor = np.random.uniform(0.7, 1.3)
augmented.append(augment_brightness(image, factor))
# Random contrast variations
for _ in range(2):
factor = np.random.uniform(0.7, 1.3)
augmented.append(augment_contrast(image, factor))
# Random saturation variations
for _ in range(2):
factor = np.random.uniform(0.5, 1.5)
augmented.append(augment_saturation(image, factor))
# Combined random adjustments
for _ in range(2):
img = image.copy()
img = augment_brightness(img, np.random.uniform(0.8, 1.2))
img = augment_contrast(img, np.random.uniform(0.8, 1.2))
img = augment_saturation(img, np.random.uniform(0.7, 1.3))
augmented.append(img)
return augmented
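# Usage sketch: color augmentation multiplies the dataset roughly 9x.
#   variants = apply_color_augmentation(frame)  # len(variants) == 9
# variants[0] is the untouched original; the rest vary brightness, contrast,
# and saturation only, so all nine share the same corner labels (color-only
# transforms never move the corners).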
# =============================================================================
# Heatmap Generation
# =============================================================================
def generate_heatmap(point: Tuple[float, float], size: int, sigma: float) -> np.ndarray:
"""
Generate a Gaussian heatmap for a single keypoint.
Args:
point: (x, y) normalized coordinates (0-1)
size: Output heatmap size
sigma: Gaussian standard deviation
Returns:
Heatmap array of shape (size, size)
"""
x, y = point
# Create coordinate grids
xx, yy = np.meshgrid(np.arange(size), np.arange(size))
# Convert normalized coords to heatmap coords
cx = x * size
cy = y * size
# Generate Gaussian
heatmap = np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2))
return heatmap.astype(np.float32)
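# Worked example: with size=56 and sigma=2.0, a corner at (0.5, 0.5) peaks at
# heatmap cell (28, 28):
#   hm = generate_heatmap((0.5, 0.5), size=56, sigma=2.0)
#   np.unravel_index(hm.argmax(), hm.shape)  # -> (28, 28)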
def generate_heatmaps(keypoints: List[Tuple[float, float]], size: int, sigma: float) -> np.ndarray:
"""
Generate heatmaps for all 4 corners.
Args:
keypoints: List of 4 (x, y) normalized coordinates
size: Output heatmap size
sigma: Gaussian standard deviation
Returns:
Array of shape (size, size, 4)
"""
heatmaps = np.stack([
generate_heatmap(kp, size, sigma) for kp in keypoints
], axis=-1)
return heatmaps
# =============================================================================
# Dataset Loading
# =============================================================================
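# Each training frame is a PNG with a sibling .json annotation. The expected
# shape, inferred from the loader below (coordinate values are illustrative,
# normalized 0-1):
#   {
#     "corners": {
#       "topLeft":     {"x": 0.12, "y": 0.08},
#       "topRight":    {"x": 0.91, "y": 0.10},
#       "bottomLeft":  {"x": 0.10, "y": 0.88},
#       "bottomRight": {"x": 0.93, "y": 0.90}
#     }
#   }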
def load_dataset(data_dir: str, image_size: int = 224, heatmap_size: int = 56,
heatmap_sigma: float = 2.0, use_json: bool = False,
apply_marker_masking: bool = True, color_augmentation: bool = False,
manifest_file: str = None):
"""
Load frames and their corner annotations.
Args:
data_dir: Directory containing training frames
image_size: Target image size
heatmap_size: Size of output heatmaps
heatmap_sigma: Gaussian sigma for heatmap generation
use_json: Output JSON progress events
apply_marker_masking: Whether to mask ArUco markers (recommended True)
color_augmentation: Whether to apply color augmentation (brightness/contrast/saturation)
manifest_file: Optional path to manifest JSON for filtered training
Returns:
images: Array of images (N, H, W, 3) normalized to [0, 1]
heatmaps: Array of heatmaps (N, heatmap_size, heatmap_size, 4)
coords: Array of coordinates (N, 4, 2) normalized to [0, 1]
"""
from PIL import Image as PILImage
data_path = Path(data_dir)
if not data_path.exists():
emit_progress("error", {
"message": f"Data directory not found: {data_dir}",
"hint": "Collect boundary training data using the training wizard"
}, use_json)
sys.exit(1)
# Check if using manifest for filtered training
if manifest_file:
emit_progress("status", {
"message": f"Loading filtered dataset from manifest: {manifest_file}",
"phase": "loading"
}, use_json)
try:
with open(manifest_file, 'r') as f:
manifest = json.load(f)
except Exception as e:
emit_progress("error", {
"message": f"Failed to load manifest file: {e}",
}, use_json)
sys.exit(1)
manifest_items = manifest.get('items', [])
if not manifest_items:
emit_progress("error", {
"message": "Manifest contains no items",
}, use_json)
sys.exit(1)
# Convert manifest items to PNG file paths
png_files = []
for item in manifest_items:
if item.get('type') != 'boundary':
continue
device_id = item.get('deviceId')
base_name = item.get('baseName')
if not device_id or not base_name:
continue
# Construct path: data_dir/deviceId/baseName.png
img_path = data_path / device_id / f"{base_name}.png"
if img_path.exists():
png_files.append(img_path)
else:
emit_progress("status", {
"message": f"Warning: Missing file from manifest: {img_path}",
"phase": "loading"
}, use_json)
total_files = len(png_files)
emit_progress("status", {
"message": f"Found {total_files} files from manifest (out of {len(manifest_items)} items)",
"phase": "loading"
}, use_json)
else:
# Find all PNG files (original behavior)
png_files = list(data_path.glob("**/*.png"))
total_files = len(png_files)
emit_progress("loading_progress", {
"step": "scanning",
"current": 0,
"total": total_files,
"message": f"Found {total_files} files to process...",
"phase": "loading"
}, use_json)
raw_samples = []
skipped = 0
for idx, png_path in enumerate(png_files):
# Emit progress every 10 files or at start/end
if idx % 10 == 0 or idx == total_files - 1:
emit_progress("loading_progress", {
"step": "loading_raw",
"current": idx + 1,
"total": total_files,
"message": f"Loading raw frames... {idx + 1}/{total_files}",
"phase": "loading"
}, use_json)
json_path = png_path.with_suffix(".json")
if not json_path.exists():
skipped += 1
continue
try:
with open(json_path, "r") as f:
annotation = json.load(f)
corners = annotation.get("corners", {})
if not all(k in corners for k in ["topLeft", "topRight", "bottomLeft", "bottomRight"]):
skipped += 1
continue
# Load image
img = PILImage.open(png_path).convert("RGB")
img_array = np.array(img)
# Extract keypoints in order: TL, TR, BL, BR
keypoints = [
(corners["topLeft"]["x"], corners["topLeft"]["y"]),
(corners["topRight"]["x"], corners["topRight"]["y"]),
(corners["bottomLeft"]["x"], corners["bottomLeft"]["y"]),
(corners["bottomRight"]["x"], corners["bottomRight"]["y"]),
]
# Apply marker masking to prevent model from learning marker patterns
# The masking uses the corner positions to locate and obscure the ArUco markers
if apply_marker_masking:
try:
img_array = mask_markers(img_array, keypoints, method="noise")
except Exception as mask_err:
# If masking fails, log warning but continue with unmasked image
emit_progress("status", {
"message": f"Warning: marker masking failed for {png_path.name}: {mask_err}",
"phase": "loading"
}, use_json)
raw_samples.append((img_array, keypoints))
except Exception as e:
emit_progress("status", {
"message": f"Error loading {png_path}: {e}",
"phase": "loading"
}, use_json)
skipped += 1
if not raw_samples:
emit_progress("error", {
"message": "No valid frames loaded",
"hint": "Ensure frames have matching .json annotation files with corner data"
}, use_json)
sys.exit(1)
# Process samples
images = []
all_heatmaps = []
all_coords = []
total_samples = len(raw_samples)
    augment_factor = 9 if color_augmentation else 1  # 9 versions per original (original + 8 augmented)
emit_progress("loading_progress", {
"step": "processing",
"current": 0,
"total": total_samples,
"message": f"Processing {total_samples} frames" + (f" with color augmentation (~{total_samples * augment_factor} total)..." if color_augmentation else "..."),
"phase": "loading"
}, use_json)
for idx, (img_array, keypoints) in enumerate(raw_samples):
# Generate heatmaps and coords (same for all augmented versions)
heatmaps = generate_heatmaps(keypoints, heatmap_size, heatmap_sigma)
coords = np.array(keypoints, dtype=np.float32)
if color_augmentation:
# Apply color augmentation (returns 9 versions including original)
augmented_images = apply_color_augmentation(img_array)
for aug_img in augmented_images:
# Resize and normalize
resized_img, _ = resize_sample(aug_img, keypoints, image_size)
resized_img = resized_img.astype(np.float32) / 255.0
images.append(resized_img)
all_heatmaps.append(heatmaps)
all_coords.append(coords)
else:
# No augmentation - just resize and normalize
resized_img, _ = resize_sample(img_array, keypoints, image_size)
resized_img = resized_img.astype(np.float32) / 255.0
images.append(resized_img)
all_heatmaps.append(heatmaps)
all_coords.append(coords)
# Emit progress every 10 samples
if (idx + 1) % 10 == 0 or idx == total_samples - 1:
emit_progress("loading_progress", {
"step": "processing",
"current": idx + 1,
"total": total_samples,
"message": f"Processing... {idx + 1}/{total_samples}" + (f" ({len(images)} with augmentation)" if color_augmentation else ""),
"phase": "loading"
}, use_json)
emit_progress("loading_progress", {
"step": "finalizing",
"current": total_samples,
"total": total_samples,
"message": "Converting to tensors...",
"phase": "loading"
}, use_json)
images = np.array(images)
all_heatmaps = np.array(all_heatmaps)
all_coords = np.array(all_coords)
emit_progress("dataset_loaded", {
"total_frames": len(images),
"raw_frames": len(raw_samples),
"skipped": skipped,
"device_count": 1, # UI expects this
"image_shape": list(images.shape),
"heatmap_shape": list(all_heatmaps.shape),
"coords_shape": list(all_coords.shape),
"marker_masking_enabled": apply_marker_masking,
"color_augmentation_enabled": color_augmentation,
"augment_factor": augment_factor,
"using_manifest": manifest_file is not None,
"phase": "loading"
}, use_json)
return images, all_heatmaps, all_coords
# =============================================================================
# Model Architecture
# =============================================================================
def create_model(input_shape: Tuple[int, int, int] = (224, 224, 3),
heatmap_size: int = 56,
num_corners: int = 4):
"""
Create heatmap regression model with MobileNetV2 backbone and feature fusion.
Architecture:
- MobileNetV2 backbone (pretrained)
- Multi-scale feature fusion (simplified BiFPN)
- Heatmap output heads
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Input
inputs = keras.Input(shape=input_shape, name="input_image")
# Preprocess for MobileNetV2 (expects [-1, 1])
x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
# MobileNetV2 backbone
backbone = keras.applications.MobileNetV2(
input_shape=input_shape,
include_top=False,
weights="imagenet",
alpha=1.0, # Full width for better features
)
# Get multi-scale features
    layer_names = [
        "block_3_expand_relu",   # 56x56, 144 channels
        "block_6_expand_relu",   # 28x28, 192 channels
        "block_13_expand_relu",  # 14x14, 576 channels
    ]
# Create a model that outputs intermediate features
backbone_outputs = [backbone.get_layer(name).output for name in layer_names]
feature_extractor = keras.Model(inputs=backbone.input, outputs=backbone_outputs)
# Freeze early layers, fine-tune later layers
for layer in backbone.layers[:100]:
layer.trainable = False
for layer in backbone.layers[100:]:
layer.trainable = True
# Extract features
features = feature_extractor(x)
f1, f2, f3 = features # 56x56, 28x28, 14x14
# Feature fusion (simplified BiFPN-style)
# Process f3 (14x14)
f3_conv = layers.Conv2D(128, 1, padding="same", activation="relu")(f3)
f3_up = layers.UpSampling2D(size=(2, 2))(f3_conv) # 28x28
# Combine with f2 (28x28)
f2_conv = layers.Conv2D(128, 1, padding="same", activation="relu")(f2)
f2_combined = layers.Add()([f2_conv, f3_up])
f2_combined = layers.Conv2D(128, 3, padding="same", activation="relu")(f2_combined)
f2_up = layers.UpSampling2D(size=(2, 2))(f2_combined) # 56x56
# Combine with f1 (56x56)
f1_conv = layers.Conv2D(128, 1, padding="same", activation="relu")(f1)
f1_combined = layers.Add()([f1_conv, f2_up])
f1_combined = layers.Conv2D(128, 3, padding="same", activation="relu")(f1_combined)
# Heatmap heads
x = layers.Conv2D(128, 3, padding="same", activation="relu")(f1_combined)
x = layers.BatchNormalization()(x)
x = layers.Conv2D(64, 3, padding="same", activation="relu")(x)
x = layers.BatchNormalization()(x)
# Output heatmaps (one per corner)
heatmaps = layers.Conv2D(
num_corners,
1,
padding="same",
activation="sigmoid",
name="heatmaps"
)(x)
# Resize heatmaps to target size if needed
if heatmap_size != 56:
heatmaps = layers.Resizing(heatmap_size, heatmap_size, name="resize_heatmaps")(heatmaps)
model = keras.Model(inputs=inputs, outputs=heatmaps, name="boundary_detector")
return model
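# Shape flow for the default 224x224 input:
#   input (224, 224, 3) -> backbone features at 56x56 / 28x28 / 14x14
#   -> top-down fusion back up to 56x56 -> heatmaps (56, 56, 4), one per corner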
# =============================================================================
# Loss Functions
# =============================================================================
def adaptive_wing_loss(y_true, y_pred, omega=14, epsilon=1, alpha=2.1, theta=0.5):
"""
Adaptive Wing Loss for heatmap regression.
From: "Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression" (ICCV 2019)
"""
import tensorflow as tf
delta = tf.abs(y_true - y_pred)
# Adaptive weight based on ground truth
weight = y_true ** alpha
A = omega * (1 / (1 + tf.pow(theta / epsilon, alpha - y_true))) * \
(alpha - y_true) * tf.pow(theta / epsilon, alpha - y_true - 1) / epsilon
C = theta * A - omega * tf.math.log(1 + tf.pow(theta / epsilon, alpha - y_true))
loss = tf.where(
delta < theta,
omega * tf.math.log(1 + tf.pow(delta / epsilon, alpha - y_true)),
A * delta - C
)
weighted_loss = weight * loss + (1 - weight) * loss * 0.1
return tf.reduce_mean(weighted_loss)
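# Behavior note: for errors below theta the log term dominates, giving strong
# gradients for nearly-correct predictions on foreground (peak) pixels; beyond
# theta the loss grows linearly, which is robust to outliers. The final
# weighting keeps a 0.1x residual loss on background pixels so they are not
# ignored entirely.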
def dsnt_decode(heatmaps):
"""
Differentiable Spatial to Numerical Transform.
Converts heatmaps to normalized coordinates in a fully differentiable way.
"""
import tensorflow as tf
shape = tf.shape(heatmaps)
batch_size = shape[0]
height = shape[1]
width = shape[2]
num_keypoints = shape[3]
# Create coordinate grids (0 to 1)
x_range = tf.linspace(0.0, 1.0, width)
y_range = tf.linspace(0.0, 1.0, height)
# Reshape for broadcasting
x_coords = tf.reshape(x_range, [1, 1, width, 1])
y_coords = tf.reshape(y_range, [1, height, 1, 1])
# Normalize heatmaps with softmax
heatmaps_flat = tf.reshape(heatmaps, [batch_size, height * width, num_keypoints])
heatmaps_softmax = tf.nn.softmax(heatmaps_flat * 10, axis=1)
heatmaps_normalized = tf.reshape(heatmaps_softmax, [batch_size, height, width, num_keypoints])
# Compute expected coordinates (soft-argmax)
x = tf.reduce_sum(x_coords * heatmaps_normalized, axis=[1, 2])
y = tf.reduce_sum(y_coords * heatmaps_normalized, axis=[1, 2])
coords = tf.stack([x, y], axis=-1)
return coords
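# Sanity check (illustrative): a heatmap with a single strong peak at cell
# (row=14, col=42) of a 56x56 grid decodes to ~(42/55, 14/55), because the
# 10x-scaled softmax is near one-hot there and the expectation collapses to
# that grid coordinate:
#   hm = np.zeros((1, 56, 56, 1), dtype=np.float32)
#   hm[0, 14, 42, 0] = 3.0
#   dsnt_decode(tf.constant(hm))  # ~[[[0.7636, 0.2545]]]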
def coordinate_loss(y_true_coords, pred_heatmaps):
"""Compute coordinate loss using DSNT decoded coordinates."""
import tensorflow as tf
pred_coords = dsnt_decode(pred_heatmaps)
# Smooth L1 loss (Huber loss)
diff = tf.abs(y_true_coords - pred_coords)
loss = tf.where(diff < 1.0, 0.5 * diff ** 2, diff - 0.5)
return tf.reduce_mean(loss)
def convexity_loss(pred_heatmaps):
"""Regularization loss to encourage convex quadrilaterals."""
import tensorflow as tf
coords = dsnt_decode(pred_heatmaps) # (batch, 4, 2)
# Corner order: TL, TR, BL, BR -> reorder to TL, TR, BR, BL for polygon
tl = coords[:, 0, :]
tr = coords[:, 1, :]
bl = coords[:, 2, :]
br = coords[:, 3, :]
corners = tf.stack([tl, tr, br, bl], axis=1)
def cross_product_2d(v1, v2):
return v1[:, 0] * v2[:, 1] - v1[:, 1] * v2[:, 0]
total_penalty = 0.0
for i in range(4):
p1 = corners[:, i, :]
p2 = corners[:, (i + 1) % 4, :]
p3 = corners[:, (i + 2) % 4, :]
v1 = p2 - p1
v2 = p3 - p2
cross = cross_product_2d(v1, v2)
total_penalty = total_penalty + tf.reduce_mean(tf.nn.relu(-cross))
return total_penalty
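# Sign convention: traversing TL -> TR -> BR -> BL in image coordinates
# (x right, y down) yields positive cross products at every corner of a convex
# quadrilateral, so relu(-cross) penalizes only concave (reflex) corners.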
# =============================================================================
# Training
# =============================================================================
def check_early_stop_signal(stop_file: str = None) -> bool:
"""Check if the user has requested early stop via a signal file."""
if stop_file is None:
return False
try:
stop_path = Path(stop_file)
if stop_path.exists():
# Remove the signal file and return True
stop_path.unlink()
return True
except Exception:
pass
return False
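# Usage sketch (illustrative path): run training with
#   --stop-file /tmp/boundary-train.stop
# then create that file (e.g. `touch /tmp/boundary-train.stop`) to request a
# graceful stop; the file is consumed (unlinked) and the best weights found so
# far are restored and saved.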
def train_model(X_train, heatmaps_train, coords_train,
X_val, heatmaps_val, coords_val,
epochs=100, batch_size=16, heatmap_size=56,
use_json=False, stop_file=None):
"""Train the boundary detection model with custom training loop."""
import tensorflow as tf
from tensorflow import keras
emit_progress("status", {
"message": "Creating model (loading MobileNetV2 backbone)...",
"phase": "training",
"step": "model_creation",
}, use_json)
model = create_model(
input_shape=X_train.shape[1:],
heatmap_size=heatmap_size,
num_corners=4
)
if not use_json:
model.summary()
emit_progress("status", {
"message": "Preparing training pipeline...",
"phase": "training",
"step": "pipeline_setup",
"total_epochs": epochs,
"batch_size": batch_size,
}, use_json)
# Optimizer with cosine decay
lr_schedule = keras.optimizers.schedules.CosineDecay(
initial_learning_rate=1e-3,
decay_steps=epochs * len(X_train) // batch_size,
alpha=1e-6
)
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
# Datasets
train_ds = tf.data.Dataset.from_tensor_slices((X_train, heatmaps_train, coords_train))
train_ds = train_ds.shuffle(len(X_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((X_val, heatmaps_val, coords_val))
val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
history = {"loss": [], "val_loss": [], "coord_mae": [], "val_coord_mae": []}
best_val_loss = float('inf')
best_weights = None
patience = 8 # Stop after 8 epochs without meaningful improvement
patience_counter = 0
min_delta = 0.001 # Minimum improvement to reset patience
@tf.function
def train_step(images, heatmaps_true, coords_true):
with tf.GradientTape() as tape:
heatmaps_pred = model(images, training=True)
heatmap_loss = adaptive_wing_loss(heatmaps_true, heatmaps_pred)
coord_loss = coordinate_loss(coords_true, heatmaps_pred)
conv_loss = convexity_loss(heatmaps_pred)
total_loss = heatmap_loss + 0.5 * coord_loss + 0.01 * conv_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
pred_coords = dsnt_decode(heatmaps_pred)
coord_mae = tf.reduce_mean(tf.abs(coords_true - pred_coords))
return total_loss, coord_mae
@tf.function
def val_step(images, heatmaps_true, coords_true):
heatmaps_pred = model(images, training=False)
heatmap_loss = adaptive_wing_loss(heatmaps_true, heatmaps_pred)
coord_loss = coordinate_loss(coords_true, heatmaps_pred)
conv_loss = convexity_loss(heatmaps_pred)
total_loss = heatmap_loss + 0.5 * coord_loss + 0.01 * conv_loss
pred_coords = dsnt_decode(heatmaps_pred)
coord_mae = tf.reduce_mean(tf.abs(coords_true - pred_coords))
return total_loss, coord_mae
emit_progress("status", {
"message": "Compiling TensorFlow graph (first epoch may be slow)...",
"phase": "training",
"step": "graph_compilation",
}, use_json)
# Calculate total batches for progress reporting
train_batches = (len(X_train) + batch_size - 1) // batch_size # Ceiling division
for epoch in range(epochs):
# Check for user-requested early stop at start of each epoch
if check_early_stop_signal(stop_file):
emit_progress("status", {
"message": f"User requested early stop at epoch {epoch + 1}. Saving best model...",
"phase": "training",
"early_graduation": True
}, use_json)
break
train_losses, train_maes = [], []
for batch_idx, (images, heatmaps_true, coords_true) in enumerate(train_ds):
# First batch of first epoch triggers graph compilation
if epoch == 0 and batch_idx == 0:
emit_progress("status", {
"message": "Running first batch (compiling graph)...",
"phase": "training",
"step": "first_batch",
}, use_json)
loss, mae = train_step(images, heatmaps_true, coords_true)
train_losses.append(loss.numpy())
train_maes.append(mae.numpy())
# After first batch, confirm compilation is done
if epoch == 0 and batch_idx == 0:
emit_progress("status", {
"message": f"Graph compiled. Training epoch 1/{epochs}...",
"phase": "training",
"step": "training_started",
}, use_json)
# Report batch progress during first epoch (every 10 batches)
if epoch == 0 and batch_idx > 0 and batch_idx % 10 == 0:
emit_progress("status", {
"message": f"Epoch 1/{epochs}: batch {batch_idx + 1}/{train_batches}",
"phase": "training",
"step": "batch_progress",
}, use_json)
# Validation phase
if epoch == 0:
emit_progress("status", {
"message": f"Epoch 1/{epochs}: running validation (compiling validation graph)...",
"phase": "training",
"step": "validation_compile",
}, use_json)
val_losses, val_maes = [], []
for batch_idx, (images, heatmaps_true, coords_true) in enumerate(val_ds):
loss, mae = val_step(images, heatmaps_true, coords_true)
val_losses.append(loss.numpy())
val_maes.append(mae.numpy())
# After first validation batch of first epoch
if epoch == 0 and batch_idx == 0:
emit_progress("status", {
"message": f"Epoch 1/{epochs}: validation in progress...",
"phase": "training",
"step": "validation_progress",
}, use_json)
train_loss = np.mean(train_losses)
val_loss = np.mean(val_losses)
train_mae = np.mean(train_maes)
val_mae = np.mean(val_maes)
history["loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["coord_mae"].append(train_mae)
history["val_coord_mae"].append(val_mae)
# Convert normalized MAE to pixel error (image is 224x224)
pixel_error = float(val_mae) * 224.0
# Generate inference samples for visualization (5 random validation samples)
if epoch == 0:
emit_progress("status", {
"message": f"Epoch 1/{epochs}: generating sample visualizations...",
"phase": "training",
"step": "generating_samples",
}, use_json)
inference_samples = generate_inference_samples(model, X_val, coords_val, num_samples=5)
emit_progress("epoch", {
"epoch": epoch + 1,
"total_epochs": epochs,
"loss": float(train_loss),
"val_loss": float(val_loss),
"coord_mae": float(train_mae),
"val_coord_mae": float(val_mae),
"accuracy": 1.0 - float(val_mae),
"val_accuracy": 1.0 - float(val_mae),
"val_pixel_error": pixel_error, # Mean corner error in pixels
"inference_samples": inference_samples, # 5 sample inferences for visualization
"phase": "training"
}, use_json)
# Only count as improvement if loss decreased by at least min_delta
if val_loss < best_val_loss - min_delta:
best_val_loss = val_loss
patience_counter = 0
best_weights = model.get_weights()
else:
patience_counter += 1
if patience_counter >= patience:
emit_progress("status", {
"message": f"Early stopping at epoch {epoch + 1} (no improvement for {patience} epochs)",
"phase": "training"
}, use_json)
break
    if best_weights is not None:
model.set_weights(best_weights)
return model, history
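# Note: coord_mae is in normalized [0, 1] units; multiplying by the 224px input
# size gives the mean corner error in pixels (e.g. val_coord_mae 0.01 ~= 2.24 px).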
# =============================================================================
# Export
# =============================================================================
def run_subprocess_with_streaming(cmd: list, use_json: bool, timeout_seconds: int = 300) -> tuple[int, str, str]:
"""
Run a subprocess with streaming output and timeout.
Returns (returncode, stdout, stderr).
Emits progress messages as output is received.
"""
    import select
    import subprocess
emit_progress("status", {"message": f"Running: {' '.join(cmd)}", "phase": "exporting"}, use_json)
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1, # Line buffered
)
stdout_lines = []
stderr_lines = []
start_time = time.time()
try:
while process.poll() is None:
# Check timeout
elapsed = time.time() - start_time
if elapsed > timeout_seconds:
process.kill()
emit_progress("status", {
"message": f"Process timed out after {timeout_seconds}s",
"phase": "exporting"
}, use_json)
return -1, '\n'.join(stdout_lines), f"Timeout after {timeout_seconds}s"
# Use select to check for available output (non-blocking)
if sys.platform != 'win32':
readable, _, _ = select.select([process.stdout, process.stderr], [], [], 1.0)
for stream in readable:
line = stream.readline()
if line:
line = line.rstrip()
if stream == process.stdout:
stdout_lines.append(line)
emit_progress("status", {
"message": f"[converter] {line[:200]}",
"phase": "exporting"
}, use_json)
else:
stderr_lines.append(line)
if line.strip(): # Only emit non-empty stderr
emit_progress("status", {
"message": f"[converter stderr] {line[:200]}",
"phase": "exporting"
}, use_json)
else:
# Windows fallback - just wait a bit
time.sleep(1.0)
emit_progress("status", {
"message": f"Converting... ({int(elapsed)}s elapsed)",
"phase": "exporting"
}, use_json)
# Read any remaining output
remaining_stdout, remaining_stderr = process.communicate(timeout=10)
if remaining_stdout:
stdout_lines.extend(remaining_stdout.split('\n'))
if remaining_stderr:
stderr_lines.extend(remaining_stderr.split('\n'))
except subprocess.TimeoutExpired:
process.kill()
return -1, '\n'.join(stdout_lines), "Process killed due to timeout"
return process.returncode, '\n'.join(stdout_lines), '\n'.join(stderr_lines)
def export_to_tfjs(model, output_dir: str, use_json: bool = False):
"""Export model to TensorFlow.js format."""
import subprocess
import tempfile
import tensorflow as tf
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
emit_progress("status", {
"message": "Exporting to TensorFlow.js format...",
"phase": "exporting"
}, use_json)
for f in output_path.glob("*.bin"):
f.unlink()
model_json = output_path / "model.json"
if model_json.exists():
model_json.unlink()
with tempfile.TemporaryDirectory() as tmpdir:
saved_model_path = Path(tmpdir) / "saved_model"
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)])
def serve(x):
return model(x, training=False)
tf.saved_model.save(model, str(saved_model_path), signatures={"serving_default": serve})
cmd = [
sys.executable, "-m", "tensorflowjs.converters.converter",
"--input_format=tf_saved_model",
"--output_format=tfjs_graph_model",
"--signature_name=serving_default",
str(saved_model_path),
str(output_path),
]
# Use streaming subprocess with 5 minute timeout
returncode, stdout, stderr = run_subprocess_with_streaming(cmd, use_json, timeout_seconds=300)
if returncode != 0:
emit_progress("status", {
"message": f"TensorFlow.js conversion warning: {stderr[-500:] if stderr else 'unknown'}",
"phase": "exporting"
}, use_json)
model_json_path = output_path / "model.json"
if model_json_path.exists():
weights_bin = list(output_path.glob("*.bin"))
total_size = model_json_path.stat().st_size
for w in weights_bin:
total_size += w.stat().st_size
emit_progress("exported", {
"output_dir": str(output_path),
"model_size_mb": round(total_size / 1024 / 1024, 2),
"phase": "exporting"
}, use_json)
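# Verification sketch (illustrative): the Keras model saved alongside the TFJS
# bundle can be reloaded and run on a dummy frame to confirm the output shape
# before deploying:
#   m = keras.models.load_model("boundary-detector.keras")
#   m(np.zeros((1, 224, 224, 3), np.float32), training=False).shape  # (1, 56, 56, 4)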
def main():
args = parse_args()
use_json = args.json_progress
# Record training start time
training_start_time = time.time()
training_start_iso = datetime.now().isoformat()
if not use_json:
print("=" * 60)
print("Abacus Boundary Detector Training (Heatmap + DSNT)")
print("=" * 60)
# Gather hardware and environment info
hardware_info = get_hardware_info()
environment_info = get_environment_info()
# Emit training_started event with all metadata
emit_progress("training_started", {
"session_id": args.session_id,
"model_type": "boundary-detector",
"started_at": training_start_iso,
"config": {
"epochs": args.epochs,
"batch_size": args.batch_size,
"validation_split": args.validation_split,
"heatmap_size": args.heatmap_size,
"heatmap_sigma": args.heatmap_sigma,
"color_augmentation": args.color_augmentation,
"marker_masking": not args.no_marker_masking,
},
"hardware": hardware_info,
"environment": environment_info,
"phase": "setup",
}, use_json)
# Check TensorFlow
try:
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
mps_devices = tf.config.list_physical_devices("MPS")
device = "MPS (Apple Silicon)" if mps_devices else ("GPU" if gpus else "CPU")
emit_progress("status", {
"message": f"TensorFlow {tf.__version__} - Using {device}",
"phase": "setup",
}, use_json)
except ImportError:
emit_progress("error", {
"message": "TensorFlow not installed",
"hint": "Install with: pip install tensorflow"
}, use_json)
sys.exit(1)
# Check tensorflowjs
tfjs_available = False
try:
import tensorflowjs
tfjs_available = True
except ImportError:
emit_progress("status", {
"message": "TensorFlow.js converter not available - will skip JS export",
"phase": "setup"
}, use_json)
# Load dataset
apply_masking = not args.no_marker_masking
if apply_masking:
emit_progress("status", {
"message": "Marker masking enabled - ArUco markers will be obscured in training data",
"phase": "loading"
}, use_json)
else:
emit_progress("status", {
"message": "WARNING: Marker masking disabled - model may learn to detect markers instead of frame edges",
"phase": "loading"
}, use_json)
if args.color_augmentation:
emit_progress("status", {
"message": "Color augmentation enabled - will generate ~9x training samples with brightness/contrast/saturation variations",
"phase": "loading"
}, use_json)
images, heatmaps, coords = load_dataset(
args.data_dir,
image_size=224,
heatmap_size=args.heatmap_size,
heatmap_sigma=args.heatmap_sigma,
use_json=use_json,
apply_marker_masking=apply_masking,
color_augmentation=args.color_augmentation,
manifest_file=args.manifest_file,
)
if len(images) < 20:
emit_progress("error", {
"message": f"Insufficient training data: {len(images)} samples (need at least 20)",
"hint": "Collect more boundary frames using the training wizard"
}, use_json)
sys.exit(1)
# Split
from sklearn.model_selection import train_test_split
X_train, X_val, hm_train, hm_val, c_train, c_val = train_test_split(
images, heatmaps, coords,
test_size=args.validation_split,
random_state=42
)
emit_progress("status", {
"message": f"Split: {len(X_train)} training, {len(X_val)} validation",
"phase": "loading",
}, use_json)
# Train
model, history = train_model(
X_train, hm_train, c_train,
X_val, hm_val, c_val,
epochs=args.epochs,
batch_size=args.batch_size,
heatmap_size=args.heatmap_size,
use_json=use_json,
stop_file=args.stop_file,
)
# Final evaluation
import tensorflow as tf
val_ds = tf.data.Dataset.from_tensor_slices((X_val, hm_val, c_val)).batch(args.batch_size)
all_maes = []
for images_batch, hm_true, coords_true in val_ds:
hm_pred = model(images_batch, training=False)
pred_coords = dsnt_decode(hm_pred)
mae = tf.reduce_mean(tf.abs(coords_true - pred_coords))
all_maes.append(mae.numpy())
final_mae = np.mean(all_maes)
# Save
output_path = Path(args.output_dir)
output_path.mkdir(parents=True, exist_ok=True)
keras_path = output_path / "boundary-detector.keras"
model.save(keras_path)
emit_progress("status", {
"message": f"Keras model saved to: {keras_path}",
"phase": "saving"
}, use_json)
tfjs_exported = False
if tfjs_available:
try:
export_to_tfjs(model, args.output_dir, use_json)
# Verify the export actually created model.json
model_json_path = output_path / "model.json"
if model_json_path.exists():
tfjs_exported = True
else:
emit_progress("status", {
"message": "WARNING: TensorFlow.js export completed but model.json not found",
"phase": "exporting"
}, use_json)
except Exception as e:
emit_progress("status", {
"message": f"TensorFlow.js export failed: {str(e)}",
"phase": "exporting"
}, use_json)
else:
emit_progress("status", {
"message": "TensorFlow.js converter not available - skipping browser export",
"phase": "exporting"
}, use_json)
# Save config
preprocessing_config = {
"model_type": "heatmap_dsnt",
"input_size": 224,
"heatmap_size": args.heatmap_size,
"num_corners": 4,
"corner_order": ["topLeft", "topRight", "bottomLeft", "bottomRight"],
"trained_at": __import__("datetime").datetime.now().isoformat(),
"training_samples": len(images),
"final_coord_mae": float(final_mae),
}
preprocessing_path = output_path / "preprocessing.json"
with open(preprocessing_path, "w") as f:
json.dump(preprocessing_config, f, indent=2)
# Convert MAE to pixel error for display (more intuitive than normalized MAE)
final_pixel_error = float(final_mae) * 224.0
# Calculate training duration
training_end_time = time.time()
training_end_iso = datetime.now().isoformat()
training_duration_seconds = training_end_time - training_start_time
# Build epoch history for graph
epoch_history = []
for i in range(len(history["loss"])):
epoch_history.append({
"epoch": i + 1,
"loss": float(history["loss"][i]),
"val_loss": float(history["val_loss"][i]),
"coord_mae": float(history["coord_mae"][i]),
"val_coord_mae": float(history["val_coord_mae"][i]),
"val_pixel_error": float(history["val_coord_mae"][i]) * 224.0,
})
emit_progress("complete", {
"final_accuracy": float(1.0 - final_mae), # Legacy field for compatibility
"final_loss": float(history["val_loss"][-1]) if history["val_loss"] else 0,
"final_mae": float(final_mae),
"final_pixel_error": final_pixel_error, # Average corner error in pixels
"epochs_trained": len(history["loss"]),
"output_dir": args.output_dir,
"model_type": "heatmap_dsnt",
"tfjs_exported": tfjs_exported, # Whether browser model was successfully created
"session_id": args.session_id, # Session ID for database tracking
# Timing info
"started_at": training_start_iso,
"completed_at": training_end_iso,
"training_duration_seconds": training_duration_seconds,
# Full epoch history for graphs
"epoch_history": epoch_history,
# Hardware and environment (repeat for complete event)
"hardware": hardware_info,
"environment": environment_info,
"phase": "complete"
}, use_json)
if __name__ == "__main__":
main()