feat(vision): add remediation advice when training produces poor results
When training accuracy is low (<50%), the results card now shows a yellow warning box explaining why accuracy might be poor and what actions to take: - Data imbalance: identifies underrepresented digits - Insufficient data: recommends collecting more images - Poor convergence: suggests checking data quality - Unknown issues: fallback advice Uses a context provider pattern to avoid prop drilling through 5 component levels (page → TrainingWizard → PhaseSection → CardCarousel → ExpandedCard → ResultsCard). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2bcdceef59
commit
e9bd3b3c61
|
|
@ -0,0 +1,163 @@
|
|||
'use client'
|
||||
|
||||
import { createContext, useContext, useMemo, type ReactNode } from 'react'
|
||||
import type {
|
||||
SamplesData,
|
||||
DatasetInfo,
|
||||
EpochData,
|
||||
TrainingConfig,
|
||||
TrainingResult,
|
||||
} from './wizard/types'
|
||||
|
||||
export interface DiagnosticReason {
|
||||
type: 'imbalance' | 'insufficient-data' | 'poor-convergence' | 'unknown'
|
||||
severity: 'warning' | 'error'
|
||||
title: string
|
||||
description: string
|
||||
action: string
|
||||
details?: {
|
||||
underrepresented?: number[]
|
||||
minCount?: number
|
||||
maxCount?: number
|
||||
totalImages?: number
|
||||
}
|
||||
}
|
||||
|
||||
interface TrainingDiagnostics {
|
||||
// Raw data
|
||||
samples: SamplesData | null
|
||||
datasetInfo: DatasetInfo | null
|
||||
epochHistory: EpochData[]
|
||||
config: TrainingConfig
|
||||
result: TrainingResult | null
|
||||
|
||||
// Computed diagnostics
|
||||
shouldShowRemediation: boolean
|
||||
reasons: DiagnosticReason[]
|
||||
}
|
||||
|
||||
const TrainingDiagnosticsContext = createContext<TrainingDiagnostics | null>(null)
|
||||
|
||||
export function useTrainingDiagnostics(): TrainingDiagnostics {
|
||||
const ctx = useContext(TrainingDiagnosticsContext)
|
||||
if (!ctx) {
|
||||
throw new Error('useTrainingDiagnostics must be used within TrainingDiagnosticsProvider')
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
interface ProviderProps {
|
||||
samples: SamplesData | null
|
||||
datasetInfo: DatasetInfo | null
|
||||
epochHistory: EpochData[]
|
||||
config: TrainingConfig
|
||||
result: TrainingResult | null
|
||||
children: ReactNode
|
||||
}
|
||||
|
||||
export function TrainingDiagnosticsProvider({
|
||||
samples,
|
||||
datasetInfo,
|
||||
epochHistory,
|
||||
config,
|
||||
result,
|
||||
children,
|
||||
}: ProviderProps) {
|
||||
const diagnostics = useMemo(() => {
|
||||
const reasons = analyzeDiagnostics(result, samples, datasetInfo, epochHistory)
|
||||
const accuracy = result?.final_accuracy ?? 0
|
||||
|
||||
return {
|
||||
samples,
|
||||
datasetInfo,
|
||||
epochHistory,
|
||||
config,
|
||||
result,
|
||||
shouldShowRemediation: accuracy < 0.5 || (accuracy < 0.7 && reasons.length > 0),
|
||||
reasons,
|
||||
}
|
||||
}, [samples, datasetInfo, epochHistory, config, result])
|
||||
|
||||
return (
|
||||
<TrainingDiagnosticsContext.Provider value={diagnostics}>
|
||||
{children}
|
||||
</TrainingDiagnosticsContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
function analyzeDiagnostics(
|
||||
result: TrainingResult | null,
|
||||
samples: SamplesData | null,
|
||||
datasetInfo: DatasetInfo | null,
|
||||
epochHistory: EpochData[]
|
||||
): DiagnosticReason[] {
|
||||
if (!result) return []
|
||||
|
||||
const accuracy = result.final_accuracy
|
||||
const reasons: DiagnosticReason[] = []
|
||||
|
||||
// 1. Check for data imbalance
|
||||
if (samples) {
|
||||
const counts = Object.values(samples.digits).map((d) => d.count)
|
||||
const max = Math.max(...counts)
|
||||
const min = Math.min(...counts)
|
||||
|
||||
if (max > min * 5 && min < 10) {
|
||||
const underrepresented = Object.entries(samples.digits)
|
||||
.filter(([, d]) => d.count < max / 3)
|
||||
.map(([digit]) => parseInt(digit, 10))
|
||||
|
||||
reasons.push({
|
||||
type: 'imbalance',
|
||||
severity: 'error',
|
||||
title: 'Data imbalance',
|
||||
description: `Some digits have very few samples (${min}) while others have many (${max})`,
|
||||
action: `Collect more samples for digits: ${underrepresented.join(', ')}`,
|
||||
details: { underrepresented, minCount: min, maxCount: max },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check for insufficient total data
|
||||
const total = datasetInfo?.total_images ?? samples?.totalImages ?? 0
|
||||
if (total < 200) {
|
||||
reasons.push({
|
||||
type: 'insufficient-data',
|
||||
severity: total < 100 ? 'error' : 'warning',
|
||||
title: 'Insufficient training data',
|
||||
description: `Only ${total} images available`,
|
||||
action: 'Collect at least 200 images (20+ per digit)',
|
||||
details: { totalImages: total },
|
||||
})
|
||||
}
|
||||
|
||||
// 3. Check for poor convergence (accuracy barely improved during training)
|
||||
if (epochHistory.length >= 2) {
|
||||
const firstAcc = epochHistory[0]?.val_accuracy ?? 0
|
||||
const lastAcc = epochHistory[epochHistory.length - 1]?.val_accuracy ?? 0
|
||||
const improvement = lastAcc - firstAcc
|
||||
|
||||
if (improvement < 0.1 && accuracy < 0.5) {
|
||||
reasons.push({
|
||||
type: 'poor-convergence',
|
||||
severity: 'warning',
|
||||
title: 'Model failed to learn',
|
||||
description: 'Accuracy barely improved during training',
|
||||
action: 'Check data quality - images may be too noisy or inconsistent',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Unknown issue if accuracy is bad but no clear reason
|
||||
if (accuracy < 0.5 && reasons.length === 0) {
|
||||
reasons.push({
|
||||
type: 'unknown',
|
||||
severity: 'warning',
|
||||
title: 'Unexpected low accuracy',
|
||||
description: 'Data appears adequate but accuracy is poor',
|
||||
action: 'Try training again or review captured images for quality issues',
|
||||
})
|
||||
}
|
||||
|
||||
return reasons
|
||||
}
|
||||
|
|
@ -141,7 +141,14 @@ export function ExpandedCard({
|
|||
case 'export':
|
||||
return <ExportCard message={statusMessage} />
|
||||
case 'results':
|
||||
return <ResultsCard result={result} error={error} onTrainAgain={onTrainAgain} />
|
||||
return (
|
||||
<ResultsCard
|
||||
result={result}
|
||||
error={error}
|
||||
configuredEpochs={config.epochs}
|
||||
onTrainAgain={onTrainAgain}
|
||||
/>
|
||||
)
|
||||
default:
|
||||
return null
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,16 +3,19 @@
|
|||
import { useState } from 'react'
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import { ModelTester } from '../../ModelTester'
|
||||
import { useTrainingDiagnostics, type DiagnosticReason } from '../../TrainingDiagnosticsContext'
|
||||
import type { TrainingResult } from '../types'
|
||||
|
||||
interface ResultsCardProps {
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
configuredEpochs: number
|
||||
onTrainAgain: () => void
|
||||
}
|
||||
|
||||
export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
||||
export function ResultsCard({ result, error, configuredEpochs, onTrainAgain }: ResultsCardProps) {
|
||||
const [activeTab, setActiveTab] = useState<'results' | 'test'>('results')
|
||||
const diagnostics = useTrainingDiagnostics()
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
|
|
@ -169,7 +172,15 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
|||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.epochs_trained ?? '—'}
|
||||
{result.epochs_trained && result.epochs_trained < configuredEpochs && (
|
||||
<span className={css({ color: 'gray.500' })}>/{configuredEpochs}</span>
|
||||
)}
|
||||
</div>
|
||||
{result.epochs_trained && result.epochs_trained < configuredEpochs && (
|
||||
<div className={css({ color: 'green.500', fontSize: 'xs', fontWeight: 'medium' })}>
|
||||
converged
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
|
||||
|
|
@ -192,6 +203,11 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
|||
</div>
|
||||
)}
|
||||
|
||||
{/* Remediation advice for poor results */}
|
||||
{diagnostics.shouldShowRemediation && (
|
||||
<RemediationSection reasons={diagnostics.reasons} />
|
||||
)}
|
||||
|
||||
{/* Train again button */}
|
||||
<button
|
||||
type="button"
|
||||
|
|
@ -218,3 +234,48 @@ export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
|||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
/** Shows diagnostic reasons and remediation advice for poor training results */
|
||||
function RemediationSection({ reasons }: { reasons: DiagnosticReason[] }) {
|
||||
if (reasons.length === 0) return null
|
||||
|
||||
return (
|
||||
<div
|
||||
data-element="remediation-section"
|
||||
className={css({
|
||||
mt: 4,
|
||||
mb: 4,
|
||||
p: 3,
|
||||
bg: 'yellow.900/30',
|
||||
border: '1px solid',
|
||||
borderColor: 'yellow.700',
|
||||
borderRadius: 'lg',
|
||||
textAlign: 'left',
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 2,
|
||||
mb: 2,
|
||||
fontWeight: 'semibold',
|
||||
color: 'yellow.400',
|
||||
fontSize: 'sm',
|
||||
})}
|
||||
>
|
||||
<span>⚠️</span>
|
||||
<span>Why is accuracy low?</span>
|
||||
</div>
|
||||
|
||||
<div className={css({ display: 'flex', flexDirection: 'column', gap: 2 })}>
|
||||
{reasons.map((reason, i) => (
|
||||
<div key={i} className={css({ fontSize: 'xs' })}>
|
||||
<div className={css({ color: 'gray.200', fontWeight: 'medium' })}>• {reason.title}</div>
|
||||
<div className={css({ color: 'gray.400', ml: 3 })}>{reason.action}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
import Link from 'next/link'
|
||||
import { css } from '../../../../styled-system/css'
|
||||
import { TrainingDiagnosticsProvider } from './components/TrainingDiagnosticsContext'
|
||||
import { TrainingWizard } from './components/wizard/TrainingWizard'
|
||||
import type {
|
||||
SamplesData,
|
||||
|
|
@ -446,29 +447,37 @@ export default function TrainModelPage() {
|
|||
</div>
|
||||
|
||||
{/* Training Wizard - handles all phases */}
|
||||
<TrainingWizard
|
||||
<TrainingDiagnosticsProvider
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
preflightInfo={preflightInfo}
|
||||
preflightLoading={preflightLoading}
|
||||
fetchPreflight={fetchPreflight}
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
serverPhase={serverPhase}
|
||||
statusMessage={statusMessage}
|
||||
currentEpoch={currentEpoch}
|
||||
epochHistory={epochHistory}
|
||||
datasetInfo={datasetInfo}
|
||||
epochHistory={epochHistory}
|
||||
config={config}
|
||||
result={result}
|
||||
error={error}
|
||||
onStart={startTraining}
|
||||
onCancel={cancelTraining}
|
||||
onReset={resetToIdle}
|
||||
onSyncComplete={fetchSamples}
|
||||
/>
|
||||
>
|
||||
<TrainingWizard
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
preflightInfo={preflightInfo}
|
||||
preflightLoading={preflightLoading}
|
||||
fetchPreflight={fetchPreflight}
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
serverPhase={serverPhase}
|
||||
statusMessage={statusMessage}
|
||||
currentEpoch={currentEpoch}
|
||||
epochHistory={epochHistory}
|
||||
datasetInfo={datasetInfo}
|
||||
result={result}
|
||||
error={error}
|
||||
onStart={startTraining}
|
||||
onCancel={cancelTraining}
|
||||
onReset={resetToIdle}
|
||||
onSyncComplete={fetchSamples}
|
||||
/>
|
||||
</TrainingDiagnosticsProvider>
|
||||
</main>
|
||||
</div>
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue