feat(vision): add model tester to training results

After training completes, users can now test the model directly:
- New "Test Model" tab in results phase
- Uses CameraCapture with marker detection
- Runs inference at ~10 FPS and shows detected value + confidence
- Color-coded confidence display (green > 80%, yellow > 50%, red otherwise)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Thomas Hallock 2026-01-06 14:45:14 -06:00
parent 94f53d8097
commit 6fb8f71565
4 changed files with 373 additions and 74 deletions

View File

@ -569,7 +569,20 @@
"Bash(brew list:*)",
"Bash(scripts/train-column-classifier/.venv/bin/pip index:*)",
"Bash(scripts/train-column-classifier/.venv/bin/pip:*)",
"WebFetch(domain:pypi.org)"
"WebFetch(domain:pypi.org)",
"Bash(\"/Users/antialias/projects/soroban-abacus-flashcards/apps/web/data/vision-training/.venv/bin/python\" -m pip install Pillow scikit-learn)",
"Bash(src/app/api/vision-training/config.ts )",
"Bash(src/app/api/vision-training/preflight/ )",
"Bash(src/app/vision-training/train/components/wizard/CardCarousel.tsx )",
"Bash(src/app/vision-training/train/components/wizard/ExpandedCard.tsx )",
"Bash(src/app/vision-training/train/components/wizard/PhaseSection.tsx )",
"Bash(src/app/vision-training/train/components/wizard/TrainingWizard.tsx )",
"Bash(src/app/vision-training/train/components/wizard/cards/DataCard.tsx )",
"Bash(src/app/vision-training/train/components/wizard/cards/DependencyCard.tsx )",
"Bash(src/app/vision-training/train/components/wizard/types.ts )",
"Bash(src/app/vision-training/train/components/TrainingDataCapture.tsx )",
"Bash(src/app/vision-training/train/page.tsx )",
"Bash(src/components/vision/CameraCapture.tsx)"
],
"deny": [],
"ask": []

View File

@ -0,0 +1,225 @@
'use client'
import { useCallback, useRef, useState } from 'react'
import { css } from '../../../../../styled-system/css'
import { CameraCapture, type CameraSource } from '@/components/vision/CameraCapture'
import { useColumnClassifier } from '@/hooks/useColumnClassifier'
import type { CalibrationGrid } from '@/types/vision'
interface ModelTesterProps {
/** Number of physical abacus columns (default 4) */
columnCount?: number
}
/**
* Test the trained model with live camera feed
*
* Shows the camera with marker detection, runs inference on each frame,
* and displays the detected value with confidence.
*/
export function ModelTester({ columnCount = 4 }: ModelTesterProps) {
const [cameraSource, setCameraSource] = useState<CameraSource>('local')
const [isPhoneConnected, setIsPhoneConnected] = useState(false)
const [calibration, setCalibration] = useState<CalibrationGrid | null>(null)
const [detectedValue, setDetectedValue] = useState<number | null>(null)
const [confidence, setConfidence] = useState<number>(0)
const [isRunning, setIsRunning] = useState(false)
const captureElementRef = useRef<HTMLImageElement | HTMLVideoElement | null>(null)
const inferenceLoopRef = useRef<number | null>(null)
const lastInferenceTimeRef = useRef<number>(0)
const classifier = useColumnClassifier()
// Handle capture from camera
const handleCapture = useCallback((element: HTMLImageElement | HTMLVideoElement) => {
captureElementRef.current = element
}, [])
// Run inference on current frame
const runInference = useCallback(async () => {
const element = captureElementRef.current
if (!element) return
// For video, check if ready
if (element instanceof HTMLVideoElement && element.readyState < 2) return
// For image, check if loaded
if (element instanceof HTMLImageElement && (!element.complete || element.naturalWidth === 0))
return
try {
// Import frame processor dynamically
const { processImageFrame } = await import('@/lib/vision/frameProcessor')
// For video, draw to temp image
let imageElement: HTMLImageElement
if (element instanceof HTMLVideoElement) {
const canvas = document.createElement('canvas')
canvas.width = element.videoWidth
canvas.height = element.videoHeight
const ctx = canvas.getContext('2d')
if (!ctx) return
ctx.drawImage(element, 0, 0)
imageElement = new Image()
imageElement.src = canvas.toDataURL('image/jpeg')
await new Promise((resolve, reject) => {
imageElement.onload = resolve
imageElement.onerror = reject
})
} else {
imageElement = element
}
// Slice into columns (using calibration if available)
const columnImages = processImageFrame(imageElement, calibration, columnCount)
if (columnImages.length === 0) return
// Run classification
const result = await classifier.classifyColumns(columnImages)
if (result) {
// Combine digits into number
const value = result.digits.reduce((acc, d) => acc * 10 + d, 0)
const minConfidence = Math.min(...result.confidences)
setDetectedValue(value)
setConfidence(minConfidence)
}
} catch (err) {
console.error('[ModelTester] Inference error:', err)
}
}, [calibration, columnCount, classifier])
// Start/stop inference loop
const toggleTesting = useCallback(() => {
if (isRunning) {
// Stop
if (inferenceLoopRef.current) {
cancelAnimationFrame(inferenceLoopRef.current)
inferenceLoopRef.current = null
}
setIsRunning(false)
setDetectedValue(null)
setConfidence(0)
} else {
// Start
setIsRunning(true)
const loop = () => {
const now = performance.now()
// Run inference at ~10 FPS
if (now - lastInferenceTimeRef.current > 100) {
lastInferenceTimeRef.current = now
runInference()
}
inferenceLoopRef.current = requestAnimationFrame(loop)
}
loop()
}
}, [isRunning, runInference])
// Check if camera is ready
const canTest = cameraSource === 'local' || isPhoneConnected
return (
<div
data-component="model-tester"
className={css({
p: 3,
bg: 'purple.900/20',
border: '1px solid',
borderColor: 'purple.700/50',
borderRadius: 'lg',
})}
>
{/* Header */}
<div className={css({ display: 'flex', alignItems: 'center', gap: 2, mb: 3 })}>
<span>🔬</span>
<span className={css({ fontWeight: 'medium', color: 'purple.300' })}>Test Model</span>
{classifier.isModelLoaded && (
<span className={css({ fontSize: 'xs', color: 'green.400', ml: 'auto' })}>
Model loaded
</span>
)}
</div>
{/* Camera */}
<CameraCapture
initialSource="local"
onCapture={handleCapture}
onSourceChange={setCameraSource}
onPhoneConnected={setIsPhoneConnected}
compact
enableMarkerDetection
columnCount={columnCount}
onCalibrationChange={setCalibration}
showRectifiedView
/>
{/* Results display */}
{isRunning && (
<div
className={css({
mt: 3,
p: 4,
bg: 'gray.900',
borderRadius: 'lg',
textAlign: 'center',
})}
>
<div
className={css({
fontSize: '4xl',
fontWeight: 'bold',
fontFamily: 'mono',
color: confidence > 0.8 ? 'green.400' : confidence > 0.5 ? 'yellow.400' : 'red.400',
mb: 1,
})}
>
{detectedValue !== null ? detectedValue : '—'}
</div>
<div className={css({ fontSize: 'sm', color: 'gray.500' })}>
Confidence: {(confidence * 100).toFixed(0)}%
</div>
</div>
)}
{/* Controls */}
{canTest && (
<div className={css({ mt: 3, display: 'flex', justifyContent: 'center' })}>
<button
type="button"
onClick={toggleTesting}
disabled={classifier.isLoading}
className={css({
px: 6,
py: 2,
bg: isRunning ? 'red.600' : 'purple.600',
color: 'white',
borderRadius: 'md',
border: 'none',
cursor: 'pointer',
fontWeight: 'medium',
_hover: { bg: isRunning ? 'red.500' : 'purple.500' },
_disabled: { opacity: 0.5, cursor: 'not-allowed' },
})}
>
{classifier.isLoading ? 'Loading...' : isRunning ? 'Stop Testing' : 'Start Testing'}
</button>
</div>
)}
{/* Calibration status */}
<div className={css({ fontSize: 'xs', color: 'gray.500', mt: 3, textAlign: 'center' })}>
{calibration ? (
<span className={css({ color: 'green.400' })}> Markers detected</span>
) : (
<span>Point camera at abacus with markers</span>
)}
</div>
</div>
)
}

View File

@ -1,6 +1,8 @@
'use client'
import { useState } from 'react'
import { css } from '../../../../../../../styled-system/css'
import { ModelTester } from '../../ModelTester'
import type { TrainingResult } from '../types'
interface ResultsCardProps {
@ -10,6 +12,8 @@ interface ResultsCardProps {
}
export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
const [activeTab, setActiveTab] = useState<'results' | 'test'>('results')
if (error) {
return (
<div className={css({ textAlign: 'center', py: 4 })}>
@ -73,87 +77,144 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
const accuracy = result.final_accuracy ?? 0
return (
<div className={css({ textAlign: 'center' })}>
{/* Success icon */}
<div className={css({ fontSize: '3xl', mb: 2 })}>🎉</div>
{/* Title */}
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'green.400', mb: 3 })}>
Training Complete!
</div>
{/* Main accuracy */}
<div>
{/* Tabs */}
<div
className={css({
fontSize: '4xl',
fontWeight: 'bold',
color: 'green.400',
mb: 0.5,
})}
>
{(accuracy * 100).toFixed(1)}%
</div>
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>Final Accuracy</div>
{/* Stats grid */}
<div
className={css({
display: 'grid',
gridTemplateColumns: 'repeat(3, 1fr)',
gap: 2,
p: 3,
bg: 'gray.900',
borderRadius: 'lg',
fontSize: 'sm',
display: 'flex',
gap: 1,
mb: 4,
borderBottom: '1px solid',
borderColor: 'gray.700',
})}
>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
{result.epochs_trained ?? '—'}
</div>
</div>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
{result.final_loss?.toFixed(4) ?? '—'}
</div>
</div>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Model</div>
<div className={css({ color: 'green.400', fontWeight: 'medium' })}>
{result.tfjs_exported ? '✓ Saved' : '—'}
</div>
</div>
<button
type="button"
onClick={() => setActiveTab('results')}
className={css({
px: 4,
py: 2,
bg: 'transparent',
color: activeTab === 'results' ? 'green.400' : 'gray.500',
border: 'none',
borderBottom: '2px solid',
borderColor: activeTab === 'results' ? 'green.400' : 'transparent',
cursor: 'pointer',
fontWeight: 'medium',
fontSize: 'sm',
_hover: { color: 'gray.300' },
})}
>
Results
</button>
<button
type="button"
onClick={() => setActiveTab('test')}
className={css({
px: 4,
py: 2,
bg: 'transparent',
color: activeTab === 'test' ? 'purple.400' : 'gray.500',
border: 'none',
borderBottom: '2px solid',
borderColor: activeTab === 'test' ? 'purple.400' : 'transparent',
cursor: 'pointer',
fontWeight: 'medium',
fontSize: 'sm',
_hover: { color: 'gray.300' },
})}
>
🔬 Test Model
</button>
</div>
{/* Model path */}
{result.tfjs_exported && (
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
Model exported and ready to use
</div>
)}
{/* Tab content */}
{activeTab === 'results' ? (
<div className={css({ textAlign: 'center' })}>
{/* Success icon */}
<div className={css({ fontSize: '3xl', mb: 2 })}>🎉</div>
{/* Train again button */}
<button
type="button"
onClick={onTrainAgain}
className={css({
px: 6,
py: 3,
bg: 'blue.600',
color: 'white',
borderRadius: 'lg',
border: 'none',
cursor: 'pointer',
fontWeight: 'bold',
fontSize: 'md',
_hover: { bg: 'blue.500' },
})}
>
Train Again
</button>
{/* Title */}
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'green.400', mb: 3 })}>
Training Complete!
</div>
{/* Main accuracy */}
<div
className={css({
fontSize: '4xl',
fontWeight: 'bold',
color: 'green.400',
mb: 0.5,
})}
>
{(accuracy * 100).toFixed(1)}%
</div>
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>Final Accuracy</div>
{/* Stats grid */}
<div
className={css({
display: 'grid',
gridTemplateColumns: 'repeat(3, 1fr)',
gap: 2,
p: 3,
bg: 'gray.900',
borderRadius: 'lg',
fontSize: 'sm',
mb: 4,
})}
>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
{result.epochs_trained ?? '—'}
</div>
</div>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
{result.final_loss?.toFixed(4) ?? '—'}
</div>
</div>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Model</div>
<div className={css({ color: 'green.400', fontWeight: 'medium' })}>
{result.tfjs_exported ? '✓ Saved' : '—'}
</div>
</div>
</div>
{/* Model path */}
{result.tfjs_exported && (
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
Model exported and ready to use
</div>
)}
{/* Train again button */}
<button
type="button"
onClick={onTrainAgain}
className={css({
px: 6,
py: 3,
bg: 'blue.600',
color: 'white',
borderRadius: 'lg',
border: 'none',
cursor: 'pointer',
fontWeight: 'bold',
fontSize: 'md',
_hover: { bg: 'blue.500' },
})}
>
Train Again
</button>
</div>
) : (
<ModelTester columnCount={4} />
)}
</div>
)
}