feat(vision): add model tester to training results
After training completes, users can now test the model directly: - New "Test Model" tab in results phase - Uses CameraCapture with marker detection - Runs inference at ~10 FPS and shows detected value + confidence - Color-coded confidence display (green > 80%, yellow > 50%, red otherwise) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
94f53d8097
commit
6fb8f71565
|
|
@ -569,7 +569,20 @@
|
|||
"Bash(brew list:*)",
|
||||
"Bash(scripts/train-column-classifier/.venv/bin/pip index:*)",
|
||||
"Bash(scripts/train-column-classifier/.venv/bin/pip:*)",
|
||||
"WebFetch(domain:pypi.org)"
|
||||
"WebFetch(domain:pypi.org)",
|
||||
"Bash(\"/Users/antialias/projects/soroban-abacus-flashcards/apps/web/data/vision-training/.venv/bin/python\" -m pip install Pillow scikit-learn)",
|
||||
"Bash(src/app/api/vision-training/config.ts )",
|
||||
"Bash(src/app/api/vision-training/preflight/ )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/CardCarousel.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/ExpandedCard.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/PhaseSection.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/TrainingWizard.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/cards/DataCard.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/cards/DependencyCard.tsx )",
|
||||
"Bash(src/app/vision-training/train/components/wizard/types.ts )",
|
||||
"Bash(src/app/vision-training/train/components/TrainingDataCapture.tsx )",
|
||||
"Bash(src/app/vision-training/train/page.tsx )",
|
||||
"Bash(src/components/vision/CameraCapture.tsx)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -0,0 +1,225 @@
|
|||
'use client'
|
||||
|
||||
import { useCallback, useRef, useState } from 'react'
|
||||
import { css } from '../../../../../styled-system/css'
|
||||
import { CameraCapture, type CameraSource } from '@/components/vision/CameraCapture'
|
||||
import { useColumnClassifier } from '@/hooks/useColumnClassifier'
|
||||
import type { CalibrationGrid } from '@/types/vision'
|
||||
|
||||
interface ModelTesterProps {
|
||||
/** Number of physical abacus columns (default 4) */
|
||||
columnCount?: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Test the trained model with live camera feed
|
||||
*
|
||||
* Shows the camera with marker detection, runs inference on each frame,
|
||||
* and displays the detected value with confidence.
|
||||
*/
|
||||
export function ModelTester({ columnCount = 4 }: ModelTesterProps) {
|
||||
const [cameraSource, setCameraSource] = useState<CameraSource>('local')
|
||||
const [isPhoneConnected, setIsPhoneConnected] = useState(false)
|
||||
const [calibration, setCalibration] = useState<CalibrationGrid | null>(null)
|
||||
const [detectedValue, setDetectedValue] = useState<number | null>(null)
|
||||
const [confidence, setConfidence] = useState<number>(0)
|
||||
const [isRunning, setIsRunning] = useState(false)
|
||||
|
||||
const captureElementRef = useRef<HTMLImageElement | HTMLVideoElement | null>(null)
|
||||
const inferenceLoopRef = useRef<number | null>(null)
|
||||
const lastInferenceTimeRef = useRef<number>(0)
|
||||
|
||||
const classifier = useColumnClassifier()
|
||||
|
||||
// Handle capture from camera
|
||||
const handleCapture = useCallback((element: HTMLImageElement | HTMLVideoElement) => {
|
||||
captureElementRef.current = element
|
||||
}, [])
|
||||
|
||||
// Run inference on current frame
|
||||
const runInference = useCallback(async () => {
|
||||
const element = captureElementRef.current
|
||||
if (!element) return
|
||||
|
||||
// For video, check if ready
|
||||
if (element instanceof HTMLVideoElement && element.readyState < 2) return
|
||||
// For image, check if loaded
|
||||
if (element instanceof HTMLImageElement && (!element.complete || element.naturalWidth === 0))
|
||||
return
|
||||
|
||||
try {
|
||||
// Import frame processor dynamically
|
||||
const { processImageFrame } = await import('@/lib/vision/frameProcessor')
|
||||
|
||||
// For video, draw to temp image
|
||||
let imageElement: HTMLImageElement
|
||||
if (element instanceof HTMLVideoElement) {
|
||||
const canvas = document.createElement('canvas')
|
||||
canvas.width = element.videoWidth
|
||||
canvas.height = element.videoHeight
|
||||
const ctx = canvas.getContext('2d')
|
||||
if (!ctx) return
|
||||
ctx.drawImage(element, 0, 0)
|
||||
|
||||
imageElement = new Image()
|
||||
imageElement.src = canvas.toDataURL('image/jpeg')
|
||||
await new Promise((resolve, reject) => {
|
||||
imageElement.onload = resolve
|
||||
imageElement.onerror = reject
|
||||
})
|
||||
} else {
|
||||
imageElement = element
|
||||
}
|
||||
|
||||
// Slice into columns (using calibration if available)
|
||||
const columnImages = processImageFrame(imageElement, calibration, columnCount)
|
||||
|
||||
if (columnImages.length === 0) return
|
||||
|
||||
// Run classification
|
||||
const result = await classifier.classifyColumns(columnImages)
|
||||
|
||||
if (result) {
|
||||
// Combine digits into number
|
||||
const value = result.digits.reduce((acc, d) => acc * 10 + d, 0)
|
||||
const minConfidence = Math.min(...result.confidences)
|
||||
|
||||
setDetectedValue(value)
|
||||
setConfidence(minConfidence)
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[ModelTester] Inference error:', err)
|
||||
}
|
||||
}, [calibration, columnCount, classifier])
|
||||
|
||||
// Start/stop inference loop
|
||||
const toggleTesting = useCallback(() => {
|
||||
if (isRunning) {
|
||||
// Stop
|
||||
if (inferenceLoopRef.current) {
|
||||
cancelAnimationFrame(inferenceLoopRef.current)
|
||||
inferenceLoopRef.current = null
|
||||
}
|
||||
setIsRunning(false)
|
||||
setDetectedValue(null)
|
||||
setConfidence(0)
|
||||
} else {
|
||||
// Start
|
||||
setIsRunning(true)
|
||||
|
||||
const loop = () => {
|
||||
const now = performance.now()
|
||||
// Run inference at ~10 FPS
|
||||
if (now - lastInferenceTimeRef.current > 100) {
|
||||
lastInferenceTimeRef.current = now
|
||||
runInference()
|
||||
}
|
||||
inferenceLoopRef.current = requestAnimationFrame(loop)
|
||||
}
|
||||
|
||||
loop()
|
||||
}
|
||||
}, [isRunning, runInference])
|
||||
|
||||
// Check if camera is ready
|
||||
const canTest = cameraSource === 'local' || isPhoneConnected
|
||||
|
||||
return (
|
||||
<div
|
||||
data-component="model-tester"
|
||||
className={css({
|
||||
p: 3,
|
||||
bg: 'purple.900/20',
|
||||
border: '1px solid',
|
||||
borderColor: 'purple.700/50',
|
||||
borderRadius: 'lg',
|
||||
})}
|
||||
>
|
||||
{/* Header */}
|
||||
<div className={css({ display: 'flex', alignItems: 'center', gap: 2, mb: 3 })}>
|
||||
<span>🔬</span>
|
||||
<span className={css({ fontWeight: 'medium', color: 'purple.300' })}>Test Model</span>
|
||||
{classifier.isModelLoaded && (
|
||||
<span className={css({ fontSize: 'xs', color: 'green.400', ml: 'auto' })}>
|
||||
Model loaded
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Camera */}
|
||||
<CameraCapture
|
||||
initialSource="local"
|
||||
onCapture={handleCapture}
|
||||
onSourceChange={setCameraSource}
|
||||
onPhoneConnected={setIsPhoneConnected}
|
||||
compact
|
||||
enableMarkerDetection
|
||||
columnCount={columnCount}
|
||||
onCalibrationChange={setCalibration}
|
||||
showRectifiedView
|
||||
/>
|
||||
|
||||
{/* Results display */}
|
||||
{isRunning && (
|
||||
<div
|
||||
className={css({
|
||||
mt: 3,
|
||||
p: 4,
|
||||
bg: 'gray.900',
|
||||
borderRadius: 'lg',
|
||||
textAlign: 'center',
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={css({
|
||||
fontSize: '4xl',
|
||||
fontWeight: 'bold',
|
||||
fontFamily: 'mono',
|
||||
color: confidence > 0.8 ? 'green.400' : confidence > 0.5 ? 'yellow.400' : 'red.400',
|
||||
mb: 1,
|
||||
})}
|
||||
>
|
||||
{detectedValue !== null ? detectedValue : '—'}
|
||||
</div>
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.500' })}>
|
||||
Confidence: {(confidence * 100).toFixed(0)}%
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Controls */}
|
||||
{canTest && (
|
||||
<div className={css({ mt: 3, display: 'flex', justifyContent: 'center' })}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={toggleTesting}
|
||||
disabled={classifier.isLoading}
|
||||
className={css({
|
||||
px: 6,
|
||||
py: 2,
|
||||
bg: isRunning ? 'red.600' : 'purple.600',
|
||||
color: 'white',
|
||||
borderRadius: 'md',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
_hover: { bg: isRunning ? 'red.500' : 'purple.500' },
|
||||
_disabled: { opacity: 0.5, cursor: 'not-allowed' },
|
||||
})}
|
||||
>
|
||||
{classifier.isLoading ? 'Loading...' : isRunning ? 'Stop Testing' : 'Start Testing'}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Calibration status */}
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mt: 3, textAlign: 'center' })}>
|
||||
{calibration ? (
|
||||
<span className={css({ color: 'green.400' })}>✓ Markers detected</span>
|
||||
) : (
|
||||
<span>Point camera at abacus with markers</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
'use client'
|
||||
|
||||
import { useState } from 'react'
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import { ModelTester } from '../../ModelTester'
|
||||
import type { TrainingResult } from '../types'
|
||||
|
||||
interface ResultsCardProps {
|
||||
|
|
@ -10,6 +12,8 @@ interface ResultsCardProps {
|
|||
}
|
||||
|
||||
export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
||||
const [activeTab, setActiveTab] = useState<'results' | 'test'>('results')
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 4 })}>
|
||||
|
|
@ -73,87 +77,144 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
|||
const accuracy = result.final_accuracy ?? 0
|
||||
|
||||
return (
|
||||
<div className={css({ textAlign: 'center' })}>
|
||||
{/* Success icon */}
|
||||
<div className={css({ fontSize: '3xl', mb: 2 })}>🎉</div>
|
||||
|
||||
{/* Title */}
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'green.400', mb: 3 })}>
|
||||
Training Complete!
|
||||
</div>
|
||||
|
||||
{/* Main accuracy */}
|
||||
<div>
|
||||
{/* Tabs */}
|
||||
<div
|
||||
className={css({
|
||||
fontSize: '4xl',
|
||||
fontWeight: 'bold',
|
||||
color: 'green.400',
|
||||
mb: 0.5,
|
||||
})}
|
||||
>
|
||||
{(accuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>Final Accuracy</div>
|
||||
|
||||
{/* Stats grid */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, 1fr)',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
bg: 'gray.900',
|
||||
borderRadius: 'lg',
|
||||
fontSize: 'sm',
|
||||
display: 'flex',
|
||||
gap: 1,
|
||||
mb: 4,
|
||||
borderBottom: '1px solid',
|
||||
borderColor: 'gray.700',
|
||||
})}
|
||||
>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.epochs_trained ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.final_loss?.toFixed(4) ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Model</div>
|
||||
<div className={css({ color: 'green.400', fontWeight: 'medium' })}>
|
||||
{result.tfjs_exported ? '✓ Saved' : '—'}
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setActiveTab('results')}
|
||||
className={css({
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'transparent',
|
||||
color: activeTab === 'results' ? 'green.400' : 'gray.500',
|
||||
border: 'none',
|
||||
borderBottom: '2px solid',
|
||||
borderColor: activeTab === 'results' ? 'green.400' : 'transparent',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
fontSize: 'sm',
|
||||
_hover: { color: 'gray.300' },
|
||||
})}
|
||||
>
|
||||
Results
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setActiveTab('test')}
|
||||
className={css({
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'transparent',
|
||||
color: activeTab === 'test' ? 'purple.400' : 'gray.500',
|
||||
border: 'none',
|
||||
borderBottom: '2px solid',
|
||||
borderColor: activeTab === 'test' ? 'purple.400' : 'transparent',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
fontSize: 'sm',
|
||||
_hover: { color: 'gray.300' },
|
||||
})}
|
||||
>
|
||||
🔬 Test Model
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Model path */}
|
||||
{result.tfjs_exported && (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
|
||||
Model exported and ready to use
|
||||
</div>
|
||||
)}
|
||||
{/* Tab content */}
|
||||
{activeTab === 'results' ? (
|
||||
<div className={css({ textAlign: 'center' })}>
|
||||
{/* Success icon */}
|
||||
<div className={css({ fontSize: '3xl', mb: 2 })}>🎉</div>
|
||||
|
||||
{/* Train again button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={onTrainAgain}
|
||||
className={css({
|
||||
px: 6,
|
||||
py: 3,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'bold',
|
||||
fontSize: 'md',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Train Again
|
||||
</button>
|
||||
{/* Title */}
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'green.400', mb: 3 })}>
|
||||
Training Complete!
|
||||
</div>
|
||||
|
||||
{/* Main accuracy */}
|
||||
<div
|
||||
className={css({
|
||||
fontSize: '4xl',
|
||||
fontWeight: 'bold',
|
||||
color: 'green.400',
|
||||
mb: 0.5,
|
||||
})}
|
||||
>
|
||||
{(accuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>Final Accuracy</div>
|
||||
|
||||
{/* Stats grid */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, 1fr)',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
bg: 'gray.900',
|
||||
borderRadius: 'lg',
|
||||
fontSize: 'sm',
|
||||
mb: 4,
|
||||
})}
|
||||
>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.epochs_trained ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.final_loss?.toFixed(4) ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Model</div>
|
||||
<div className={css({ color: 'green.400', fontWeight: 'medium' })}>
|
||||
{result.tfjs_exported ? '✓ Saved' : '—'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model path */}
|
||||
{result.tfjs_exported && (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
|
||||
Model exported and ready to use
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Train again button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={onTrainAgain}
|
||||
className={css({
|
||||
px: 6,
|
||||
py: 3,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'bold',
|
||||
fontSize: 'md',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Train Again
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<ModelTester columnCount={4} />
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue