feat(vision): add remediation advice when training produces poor results

When training accuracy is low (<50%), the results card now shows
a yellow warning box explaining why accuracy might be poor and
what actions to take:

- Data imbalance: identifies underrepresented digits
- Insufficient data: recommends collecting more images
- Poor convergence: suggests checking data quality
- Unknown issues: fallback advice

Uses a context provider pattern to avoid prop drilling through
5 component levels (page → TrainingWizard → PhaseSection →
CardCarousel → ExpandedCard → ResultsCard).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Thomas Hallock 2026-01-06 17:06:56 -06:00
parent 2bcdceef59
commit e9bd3b3c61
4 changed files with 262 additions and 22 deletions

View File

@ -0,0 +1,163 @@
'use client'
import { createContext, useContext, useMemo, type ReactNode } from 'react'
import type {
SamplesData,
DatasetInfo,
EpochData,
TrainingConfig,
TrainingResult,
} from './wizard/types'
export interface DiagnosticReason {
type: 'imbalance' | 'insufficient-data' | 'poor-convergence' | 'unknown'
severity: 'warning' | 'error'
title: string
description: string
action: string
details?: {
underrepresented?: number[]
minCount?: number
maxCount?: number
totalImages?: number
}
}
interface TrainingDiagnostics {
// Raw data
samples: SamplesData | null
datasetInfo: DatasetInfo | null
epochHistory: EpochData[]
config: TrainingConfig
result: TrainingResult | null
// Computed diagnostics
shouldShowRemediation: boolean
reasons: DiagnosticReason[]
}
const TrainingDiagnosticsContext = createContext<TrainingDiagnostics | null>(null)
export function useTrainingDiagnostics(): TrainingDiagnostics {
const ctx = useContext(TrainingDiagnosticsContext)
if (!ctx) {
throw new Error('useTrainingDiagnostics must be used within TrainingDiagnosticsProvider')
}
return ctx
}
interface ProviderProps {
samples: SamplesData | null
datasetInfo: DatasetInfo | null
epochHistory: EpochData[]
config: TrainingConfig
result: TrainingResult | null
children: ReactNode
}
export function TrainingDiagnosticsProvider({
samples,
datasetInfo,
epochHistory,
config,
result,
children,
}: ProviderProps) {
const diagnostics = useMemo(() => {
const reasons = analyzeDiagnostics(result, samples, datasetInfo, epochHistory)
const accuracy = result?.final_accuracy ?? 0
return {
samples,
datasetInfo,
epochHistory,
config,
result,
shouldShowRemediation: accuracy < 0.5 || (accuracy < 0.7 && reasons.length > 0),
reasons,
}
}, [samples, datasetInfo, epochHistory, config, result])
return (
<TrainingDiagnosticsContext.Provider value={diagnostics}>
{children}
</TrainingDiagnosticsContext.Provider>
)
}
function analyzeDiagnostics(
result: TrainingResult | null,
samples: SamplesData | null,
datasetInfo: DatasetInfo | null,
epochHistory: EpochData[]
): DiagnosticReason[] {
if (!result) return []
const accuracy = result.final_accuracy
const reasons: DiagnosticReason[] = []
// 1. Check for data imbalance
if (samples) {
const counts = Object.values(samples.digits).map((d) => d.count)
const max = Math.max(...counts)
const min = Math.min(...counts)
if (max > min * 5 && min < 10) {
const underrepresented = Object.entries(samples.digits)
.filter(([, d]) => d.count < max / 3)
.map(([digit]) => parseInt(digit, 10))
reasons.push({
type: 'imbalance',
severity: 'error',
title: 'Data imbalance',
description: `Some digits have very few samples (${min}) while others have many (${max})`,
action: `Collect more samples for digits: ${underrepresented.join(', ')}`,
details: { underrepresented, minCount: min, maxCount: max },
})
}
}
// 2. Check for insufficient total data
const total = datasetInfo?.total_images ?? samples?.totalImages ?? 0
if (total < 200) {
reasons.push({
type: 'insufficient-data',
severity: total < 100 ? 'error' : 'warning',
title: 'Insufficient training data',
description: `Only ${total} images available`,
action: 'Collect at least 200 images (20+ per digit)',
details: { totalImages: total },
})
}
// 3. Check for poor convergence (accuracy barely improved during training)
if (epochHistory.length >= 2) {
const firstAcc = epochHistory[0]?.val_accuracy ?? 0
const lastAcc = epochHistory[epochHistory.length - 1]?.val_accuracy ?? 0
const improvement = lastAcc - firstAcc
if (improvement < 0.1 && accuracy < 0.5) {
reasons.push({
type: 'poor-convergence',
severity: 'warning',
title: 'Model failed to learn',
description: 'Accuracy barely improved during training',
action: 'Check data quality - images may be too noisy or inconsistent',
})
}
}
// 4. Unknown issue if accuracy is bad but no clear reason
if (accuracy < 0.5 && reasons.length === 0) {
reasons.push({
type: 'unknown',
severity: 'warning',
title: 'Unexpected low accuracy',
description: 'Data appears adequate but accuracy is poor',
action: 'Try training again or review captured images for quality issues',
})
}
return reasons
}

View File

@ -141,7 +141,14 @@ export function ExpandedCard({
case 'export':
return <ExportCard message={statusMessage} />
case 'results':
return <ResultsCard result={result} error={error} onTrainAgain={onTrainAgain} />
return (
<ResultsCard
result={result}
error={error}
configuredEpochs={config.epochs}
onTrainAgain={onTrainAgain}
/>
)
default:
return null
}

View File

@ -3,16 +3,19 @@
import { useState } from 'react'
import { css } from '../../../../../../../styled-system/css'
import { ModelTester } from '../../ModelTester'
import { useTrainingDiagnostics, type DiagnosticReason } from '../../TrainingDiagnosticsContext'
import type { TrainingResult } from '../types'
interface ResultsCardProps {
result: TrainingResult | null
error: string | null
configuredEpochs: number
onTrainAgain: () => void
}
export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
export function ResultsCard({ result, error, configuredEpochs, onTrainAgain }: ResultsCardProps) {
const [activeTab, setActiveTab] = useState<'results' | 'test'>('results')
const diagnostics = useTrainingDiagnostics()
if (error) {
return (
@ -169,7 +172,15 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
{result.epochs_trained ?? '—'}
{result.epochs_trained && result.epochs_trained < configuredEpochs && (
<span className={css({ color: 'gray.500' })}>/{configuredEpochs}</span>
)}
</div>
{result.epochs_trained && result.epochs_trained < configuredEpochs && (
<div className={css({ color: 'green.500', fontSize: 'xs', fontWeight: 'medium' })}>
converged
</div>
)}
</div>
<div>
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
@ -192,6 +203,11 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
</div>
)}
{/* Remediation advice for poor results */}
{diagnostics.shouldShowRemediation && (
<RemediationSection reasons={diagnostics.reasons} />
)}
{/* Train again button */}
<button
type="button"
@ -218,3 +234,48 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
</div>
)
}
/** Shows diagnostic reasons and remediation advice for poor training results */
function RemediationSection({ reasons }: { reasons: DiagnosticReason[] }) {
if (reasons.length === 0) return null
return (
<div
data-element="remediation-section"
className={css({
mt: 4,
mb: 4,
p: 3,
bg: 'yellow.900/30',
border: '1px solid',
borderColor: 'yellow.700',
borderRadius: 'lg',
textAlign: 'left',
})}
>
<div
className={css({
display: 'flex',
alignItems: 'center',
gap: 2,
mb: 2,
fontWeight: 'semibold',
color: 'yellow.400',
fontSize: 'sm',
})}
>
<span></span>
<span>Why is accuracy low?</span>
</div>
<div className={css({ display: 'flex', flexDirection: 'column', gap: 2 })}>
{reasons.map((reason, i) => (
<div key={i} className={css({ fontSize: 'xs' })}>
<div className={css({ color: 'gray.200', fontWeight: 'medium' })}> {reason.title}</div>
<div className={css({ color: 'gray.400', ml: 3 })}>{reason.action}</div>
</div>
))}
</div>
</div>
)
}

View File

@ -3,6 +3,7 @@
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import Link from 'next/link'
import { css } from '../../../../styled-system/css'
import { TrainingDiagnosticsProvider } from './components/TrainingDiagnosticsContext'
import { TrainingWizard } from './components/wizard/TrainingWizard'
import type {
SamplesData,
@ -446,29 +447,37 @@ export default function TrainModelPage() {
</div>
{/* Training Wizard - handles all phases */}
<TrainingWizard
<TrainingDiagnosticsProvider
samples={samples}
samplesLoading={samplesLoading}
hardwareInfo={hardwareInfo}
hardwareLoading={hardwareLoading}
fetchHardware={fetchHardware}
preflightInfo={preflightInfo}
preflightLoading={preflightLoading}
fetchPreflight={fetchPreflight}
config={config}
setConfig={setConfig}
serverPhase={serverPhase}
statusMessage={statusMessage}
currentEpoch={currentEpoch}
epochHistory={epochHistory}
datasetInfo={datasetInfo}
epochHistory={epochHistory}
config={config}
result={result}
error={error}
onStart={startTraining}
onCancel={cancelTraining}
onReset={resetToIdle}
onSyncComplete={fetchSamples}
/>
>
<TrainingWizard
samples={samples}
samplesLoading={samplesLoading}
hardwareInfo={hardwareInfo}
hardwareLoading={hardwareLoading}
fetchHardware={fetchHardware}
preflightInfo={preflightInfo}
preflightLoading={preflightLoading}
fetchPreflight={fetchPreflight}
config={config}
setConfig={setConfig}
serverPhase={serverPhase}
statusMessage={statusMessage}
currentEpoch={currentEpoch}
epochHistory={epochHistory}
datasetInfo={datasetInfo}
result={result}
error={error}
onStart={startTraining}
onCancel={cancelTraining}
onReset={resetToIdle}
onSyncComplete={fetchSamples}
/>
</TrainingDiagnosticsProvider>
</main>
</div>
)