feat(vision): wizard-style training UI with production sync
Transform training pipeline into a phase-based wizard stepper: - Preparation phase: data check, hardware detection, config - Training phase: setup, loading, training progress, export - Results phase: final accuracy and train again option Add sync UI to pull training images from production NAS: - SSE-based progress updates during rsync - Shows remote vs local image counts - Skip option for users who don't need sync New wizard components: - TrainingWizard: main orchestrator with phase management - PhaseSection: collapsible phase containers - CardCarousel: left/center/right card positioning - CollapsedCard: compact done/upcoming cards with rich previews - ExpandedCard: full card with content for current step - Individual card components for each step APIs added: - /api/vision-training/samples: check training data status - /api/vision-training/sync: rsync from production with SSE 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -57,8 +57,7 @@ interface SetupResult {
|
||||
* On Apple Silicon, prefer Homebrew ARM Python for Metal GPU support.
|
||||
*/
|
||||
function findBestPython(): string {
|
||||
const isAppleSilicon =
|
||||
process.platform === 'darwin' && process.arch === 'arm64'
|
||||
const isAppleSilicon = process.platform === 'darwin' && process.arch === 'arm64'
|
||||
|
||||
if (isAppleSilicon) {
|
||||
// Try Homebrew Python versions (ARM native)
|
||||
@@ -110,8 +109,7 @@ async function isVenvReady(): Promise<boolean> {
|
||||
* Create the venv and install dependencies
|
||||
*/
|
||||
async function createVenv(): Promise<SetupResult> {
|
||||
const isAppleSilicon =
|
||||
process.platform === 'darwin' && process.arch === 'arm64'
|
||||
const isAppleSilicon = process.platform === 'darwin' && process.arch === 'arm64'
|
||||
|
||||
const basePython = findBestPython()
|
||||
console.log(`[vision-training] Creating venv with ${basePython}...`)
|
||||
@@ -147,9 +145,7 @@ async function createVenv(): Promise<SetupResult> {
|
||||
)
|
||||
const gpuCount = parseInt(stdout.trim(), 10) || 0
|
||||
|
||||
console.log(
|
||||
`[vision-training] Setup complete. GPUs detected: ${gpuCount}`
|
||||
)
|
||||
console.log(`[vision-training] Setup complete. GPUs detected: ${gpuCount}`)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
@@ -185,8 +181,7 @@ export async function ensureVenvReady(): Promise<SetupResult> {
|
||||
|
||||
// Check if already set up
|
||||
if (await isVenvReady()) {
|
||||
const isAppleSilicon =
|
||||
process.platform === 'darwin' && process.arch === 'arm64'
|
||||
const isAppleSilicon = process.platform === 'darwin' && process.arch === 'arm64'
|
||||
|
||||
// Quick GPU check
|
||||
let hasGpu = false
|
||||
|
||||
113
apps/web/src/app/api/vision-training/samples/route.ts
Normal file
113
apps/web/src/app/api/vision-training/samples/route.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
const DATA_DIR = path.join(process.cwd(), 'data/vision-training/collected')
|
||||
|
||||
interface DigitSample {
|
||||
count: number
|
||||
samplePath: string | null
|
||||
// For background tiling - random selection of image paths
|
||||
tilePaths: string[]
|
||||
}
|
||||
|
||||
interface SamplesResponse {
|
||||
digits: Record<number, DigitSample>
|
||||
totalImages: number
|
||||
hasData: boolean
|
||||
dataQuality: 'none' | 'insufficient' | 'minimal' | 'good' | 'excellent'
|
||||
}
|
||||
|
||||
/**
|
||||
* GET /api/vision-training/samples
|
||||
*
|
||||
* Returns sample images and counts for each digit (0-9).
|
||||
* Used for the training preview UI.
|
||||
*/
|
||||
export async function GET(): Promise<Response> {
|
||||
const digits: Record<number, DigitSample> = {}
|
||||
let totalImages = 0
|
||||
|
||||
// Initialize all digits
|
||||
for (let d = 0; d <= 9; d++) {
|
||||
digits[d] = { count: 0, samplePath: null, tilePaths: [] }
|
||||
}
|
||||
|
||||
try {
|
||||
// Check if data directory exists
|
||||
if (!fs.existsSync(DATA_DIR)) {
|
||||
return Response.json({
|
||||
digits,
|
||||
totalImages: 0,
|
||||
hasData: false,
|
||||
dataQuality: 'none',
|
||||
} satisfies SamplesResponse)
|
||||
}
|
||||
|
||||
// Scan each digit directory
|
||||
for (let d = 0; d <= 9; d++) {
|
||||
const digitDir = path.join(DATA_DIR, String(d))
|
||||
|
||||
if (!fs.existsSync(digitDir)) {
|
||||
continue
|
||||
}
|
||||
|
||||
const files = fs
|
||||
.readdirSync(digitDir)
|
||||
.filter((f) => /\.(png|jpg|jpeg|webp)$/i.test(f))
|
||||
.sort() // Consistent ordering
|
||||
|
||||
const count = files.length
|
||||
totalImages += count
|
||||
|
||||
if (count > 0) {
|
||||
// Pick a representative sample (middle of the list for variety)
|
||||
const sampleIndex = Math.floor(count / 2)
|
||||
const sampleFile = files[sampleIndex]
|
||||
|
||||
// Pick random files for background tiling (up to 5 per digit)
|
||||
const tileCount = Math.min(5, count)
|
||||
const tileIndices = new Set<number>()
|
||||
while (tileIndices.size < tileCount) {
|
||||
tileIndices.add(Math.floor(Math.random() * count))
|
||||
}
|
||||
const tilePaths = Array.from(tileIndices).map(
|
||||
(i) => `/api/vision-training/images/${d}/${files[i]}`
|
||||
)
|
||||
|
||||
digits[d] = {
|
||||
count,
|
||||
samplePath: `/api/vision-training/images/${d}/${sampleFile}`,
|
||||
tilePaths,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine data quality based on total and distribution
|
||||
let dataQuality: SamplesResponse['dataQuality'] = 'none'
|
||||
const digitCounts = Object.values(digits).map((d) => d.count)
|
||||
const minCount = Math.min(...digitCounts)
|
||||
const avgCount = totalImages / 10
|
||||
|
||||
if (totalImages === 0) {
|
||||
dataQuality = 'none'
|
||||
} else if (totalImages < 50 || minCount < 3) {
|
||||
dataQuality = 'insufficient'
|
||||
} else if (totalImages < 200 || minCount < 10) {
|
||||
dataQuality = 'minimal'
|
||||
} else if (totalImages < 500 || avgCount < 40) {
|
||||
dataQuality = 'good'
|
||||
} else {
|
||||
dataQuality = 'excellent'
|
||||
}
|
||||
|
||||
return Response.json({
|
||||
digits,
|
||||
totalImages,
|
||||
hasData: totalImages > 0,
|
||||
dataQuality,
|
||||
} satisfies SamplesResponse)
|
||||
} catch (error) {
|
||||
console.error('[vision-training/samples] Error:', error)
|
||||
return Response.json({ error: 'Failed to read training samples' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
188
apps/web/src/app/api/vision-training/sync/route.ts
Normal file
188
apps/web/src/app/api/vision-training/sync/route.ts
Normal file
@@ -0,0 +1,188 @@
|
||||
import { spawn } from 'child_process'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
// Configuration - should match the sync script
|
||||
const REMOTE_HOST = 'nas.home.network'
|
||||
const REMOTE_USER = 'antialias'
|
||||
const REMOTE_PATH = '/volume1/homes/antialias/projects/abaci.one/data/vision-training/collected/'
|
||||
const LOCAL_PATH = path.join(process.cwd(), 'data/vision-training/collected/')
|
||||
|
||||
/**
|
||||
* POST /api/vision-training/sync
|
||||
* Sync training data from production to local using rsync
|
||||
* Returns SSE stream with progress updates
|
||||
*/
|
||||
export async function POST() {
|
||||
const encoder = new TextEncoder()
|
||||
|
||||
const stream = new ReadableStream({
|
||||
async start(controller) {
|
||||
const send = (event: string, data: Record<string, unknown>) => {
|
||||
controller.enqueue(encoder.encode(`event: ${event}\ndata: ${JSON.stringify(data)}\n\n`))
|
||||
}
|
||||
|
||||
try {
|
||||
// Ensure local directory exists
|
||||
await fs.promises.mkdir(LOCAL_PATH, { recursive: true })
|
||||
|
||||
send('status', { message: 'Connecting to production server...', phase: 'connecting' })
|
||||
|
||||
// Run rsync with progress
|
||||
const rsync = spawn('rsync', [
|
||||
'-avz',
|
||||
'--progress',
|
||||
'--stats',
|
||||
`${REMOTE_USER}@${REMOTE_HOST}:${REMOTE_PATH}`,
|
||||
LOCAL_PATH,
|
||||
])
|
||||
|
||||
let currentFile = ''
|
||||
let filesTransferred = 0
|
||||
let totalBytes = 0
|
||||
|
||||
rsync.stdout.on('data', (data: Buffer) => {
|
||||
const output = data.toString()
|
||||
const lines = output.split('\n')
|
||||
|
||||
for (const line of lines) {
|
||||
// File being transferred (lines that end with file size info)
|
||||
const fileMatch = line.match(/^(\d+\/\d+\.png)/)
|
||||
if (fileMatch) {
|
||||
currentFile = fileMatch[1]
|
||||
}
|
||||
|
||||
// Progress line: " 1,234,567 100% 12.34MB/s 0:00:01"
|
||||
const progressMatch = line.match(/^\s*([\d,]+)\s+(\d+)%/)
|
||||
if (progressMatch) {
|
||||
const bytes = parseInt(progressMatch[1].replace(/,/g, ''), 10)
|
||||
const percent = parseInt(progressMatch[2], 10)
|
||||
totalBytes += bytes
|
||||
|
||||
if (percent === 100) {
|
||||
filesTransferred++
|
||||
}
|
||||
|
||||
send('progress', {
|
||||
currentFile,
|
||||
filesTransferred,
|
||||
bytesTransferred: totalBytes,
|
||||
message: `Syncing: ${currentFile || 'files'}...`,
|
||||
})
|
||||
}
|
||||
|
||||
// Stats at the end
|
||||
const statsMatch = line.match(/Number of regular files transferred:\s*(\d+)/)
|
||||
if (statsMatch) {
|
||||
filesTransferred = parseInt(statsMatch[1], 10)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
rsync.stderr.on('data', (data: Buffer) => {
|
||||
const output = data.toString()
|
||||
// Ignore SSH banner/warnings, only report actual errors
|
||||
if (output.includes('error') || output.includes('failed')) {
|
||||
send('error', { message: output.trim() })
|
||||
}
|
||||
})
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
rsync.on('close', (code) => {
|
||||
if (code === 0) {
|
||||
resolve()
|
||||
} else {
|
||||
reject(new Error(`rsync exited with code ${code}`))
|
||||
}
|
||||
})
|
||||
rsync.on('error', reject)
|
||||
})
|
||||
|
||||
// Count final stats
|
||||
const stats = await countLocalImages()
|
||||
|
||||
send('complete', {
|
||||
filesTransferred,
|
||||
...stats,
|
||||
})
|
||||
} catch (error) {
|
||||
send('error', {
|
||||
message: error instanceof Error ? error.message : 'Sync failed',
|
||||
})
|
||||
} finally {
|
||||
controller.close()
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
return new Response(stream, {
|
||||
headers: {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
Connection: 'keep-alive',
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* GET /api/vision-training/sync
|
||||
* Check if sync is available (SSH connectivity)
|
||||
*/
|
||||
export async function GET() {
|
||||
try {
|
||||
// Quick check if we can reach the remote host
|
||||
const { exec } = await import('child_process')
|
||||
const { promisify } = await import('util')
|
||||
const execAsync = promisify(exec)
|
||||
|
||||
// Test SSH connection with timeout
|
||||
await execAsync(`ssh -o ConnectTimeout=3 -o BatchMode=yes ${REMOTE_USER}@${REMOTE_HOST} "echo ok"`, {
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
// Get remote stats
|
||||
const { stdout } = await execAsync(
|
||||
`ssh ${REMOTE_USER}@${REMOTE_HOST} "find '${REMOTE_PATH}' -name '*.png' 2>/dev/null | wc -l"`,
|
||||
{ timeout: 10000 }
|
||||
)
|
||||
const remoteCount = parseInt(stdout.trim(), 10) || 0
|
||||
|
||||
// Get local stats
|
||||
const localStats = await countLocalImages()
|
||||
|
||||
return Response.json({
|
||||
available: true,
|
||||
remote: {
|
||||
host: REMOTE_HOST,
|
||||
totalImages: remoteCount,
|
||||
},
|
||||
local: localStats,
|
||||
needsSync: remoteCount > localStats.totalImages,
|
||||
})
|
||||
} catch {
|
||||
return Response.json({
|
||||
available: false,
|
||||
error: 'Cannot connect to production server',
|
||||
local: await countLocalImages(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async function countLocalImages(): Promise<{ totalImages: number; digitCounts: Record<number, number> }> {
|
||||
const digitCounts: Record<number, number> = {}
|
||||
let totalImages = 0
|
||||
|
||||
for (let digit = 0; digit <= 9; digit++) {
|
||||
const digitPath = path.join(LOCAL_PATH, String(digit))
|
||||
try {
|
||||
const files = await fs.promises.readdir(digitPath)
|
||||
const pngCount = files.filter((f) => f.endsWith('.png')).length
|
||||
digitCounts[digit] = pngCount
|
||||
totalImages += pngCount
|
||||
} catch {
|
||||
digitCounts[digit] = 0
|
||||
}
|
||||
}
|
||||
|
||||
return { totalImages, digitCounts }
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
import { CollapsedCard } from './CollapsedCard'
|
||||
import { ExpandedCard } from './ExpandedCard'
|
||||
import {
|
||||
CARDS,
|
||||
type CardId,
|
||||
type SamplesData,
|
||||
type HardwareInfo,
|
||||
type TrainingConfig,
|
||||
type ServerPhase,
|
||||
type EpochData,
|
||||
type DatasetInfo,
|
||||
type TrainingResult,
|
||||
} from './types'
|
||||
|
||||
interface CardCarouselProps {
|
||||
cards: CardId[]
|
||||
currentCardIndex: number
|
||||
// Data
|
||||
samples: SamplesData | null
|
||||
samplesLoading: boolean
|
||||
hardwareInfo: HardwareInfo | null
|
||||
hardwareLoading: boolean
|
||||
fetchHardware: () => void
|
||||
config: TrainingConfig
|
||||
setConfig: (config: TrainingConfig | ((prev: TrainingConfig) => TrainingConfig)) => void
|
||||
isGpu: boolean
|
||||
// Training
|
||||
serverPhase: ServerPhase
|
||||
statusMessage: string
|
||||
currentEpoch: EpochData | null
|
||||
bestAccuracy: number
|
||||
datasetInfo: DatasetInfo | null
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
// Summaries
|
||||
getCardSummary: (cardId: string) => { label: string; value: string } | null
|
||||
// Actions
|
||||
onProgress: () => void
|
||||
onStartTraining: () => void
|
||||
onCancel: () => void
|
||||
onTrainAgain: () => void
|
||||
onSyncComplete?: () => void
|
||||
canStartTraining: boolean
|
||||
}
|
||||
|
||||
export function CardCarousel({
|
||||
cards,
|
||||
currentCardIndex,
|
||||
samples,
|
||||
samplesLoading,
|
||||
hardwareInfo,
|
||||
hardwareLoading,
|
||||
fetchHardware,
|
||||
config,
|
||||
setConfig,
|
||||
isGpu,
|
||||
serverPhase,
|
||||
statusMessage,
|
||||
currentEpoch,
|
||||
bestAccuracy,
|
||||
datasetInfo,
|
||||
result,
|
||||
error,
|
||||
getCardSummary,
|
||||
onProgress,
|
||||
onStartTraining,
|
||||
onCancel,
|
||||
onTrainAgain,
|
||||
onSyncComplete,
|
||||
canStartTraining,
|
||||
}: CardCarouselProps) {
|
||||
// Generate preview for upcoming cards based on known data
|
||||
// Can return a simple string or a rich object with multiple lines
|
||||
const getCardPreview = (cardId: CardId): { primary: string; secondary?: string; tertiary?: string } | string => {
|
||||
switch (cardId) {
|
||||
case 'data':
|
||||
if (samples?.hasData) {
|
||||
return {
|
||||
primary: `${samples.totalImages} images`,
|
||||
secondary: samples.dataQuality === 'excellent' ? 'Excellent' :
|
||||
samples.dataQuality === 'good' ? 'Good quality' :
|
||||
samples.dataQuality === 'minimal' ? 'Minimal' : 'Ready',
|
||||
}
|
||||
}
|
||||
return 'Check data'
|
||||
|
||||
case 'hardware':
|
||||
if (hardwareInfo && !hardwareInfo.error) {
|
||||
const shortName = hardwareInfo.deviceName.length > 12
|
||||
? hardwareInfo.deviceName.split(' ').slice(0, 2).join(' ')
|
||||
: hardwareInfo.deviceName
|
||||
return {
|
||||
primary: shortName,
|
||||
secondary: hardwareInfo.deviceType === 'gpu' ? 'GPU Accel' : 'CPU Mode',
|
||||
}
|
||||
}
|
||||
return 'Detect HW'
|
||||
|
||||
case 'config':
|
||||
return {
|
||||
primary: `${config.epochs} epochs`,
|
||||
secondary: `Batch ${config.batchSize}`,
|
||||
tertiary: config.augmentation ? 'Augment: On' : 'Augment: Off',
|
||||
}
|
||||
|
||||
case 'setup':
|
||||
return 'Initialize'
|
||||
|
||||
case 'loading':
|
||||
return samples?.totalImages ? `Load ${samples.totalImages}` : 'Load data'
|
||||
|
||||
case 'training':
|
||||
return {
|
||||
primary: `${config.epochs} epochs`,
|
||||
secondary: config.augmentation ? 'w/ augment' : undefined,
|
||||
}
|
||||
|
||||
case 'export':
|
||||
return 'TF.js export'
|
||||
|
||||
case 'results':
|
||||
return 'View results'
|
||||
|
||||
default:
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
data-element="card-carousel"
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 3,
|
||||
mb: 4,
|
||||
})}
|
||||
>
|
||||
{/* Done cards (left side) */}
|
||||
<div className={css({ display: 'flex', gap: 2 })}>
|
||||
{cards.slice(0, currentCardIndex).map((cardId) => {
|
||||
const cardDef = CARDS[cardId]
|
||||
const summary = getCardSummary(cardId)
|
||||
return (
|
||||
<CollapsedCard
|
||||
key={cardId}
|
||||
icon={cardDef.icon}
|
||||
title={cardDef.title}
|
||||
summary={summary?.value}
|
||||
status="done"
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Current card (center, expanded) */}
|
||||
{currentCardIndex >= 0 && currentCardIndex < cards.length && (
|
||||
<ExpandedCard
|
||||
cardId={cards[currentCardIndex]}
|
||||
// Data
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
isGpu={isGpu}
|
||||
// Training
|
||||
serverPhase={serverPhase}
|
||||
statusMessage={statusMessage}
|
||||
currentEpoch={currentEpoch}
|
||||
bestAccuracy={bestAccuracy}
|
||||
datasetInfo={datasetInfo}
|
||||
result={result}
|
||||
error={error}
|
||||
// Actions
|
||||
onProgress={onProgress}
|
||||
onStartTraining={onStartTraining}
|
||||
onCancel={onCancel}
|
||||
onTrainAgain={onTrainAgain}
|
||||
onSyncComplete={onSyncComplete}
|
||||
canStartTraining={canStartTraining}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Upcoming cards (right side) */}
|
||||
<div className={css({ display: 'flex', gap: 2 })}>
|
||||
{cards.slice(currentCardIndex + 1).map((cardId) => {
|
||||
const cardDef = CARDS[cardId]
|
||||
return (
|
||||
<CollapsedCard
|
||||
key={cardId}
|
||||
icon={cardDef.icon}
|
||||
title={cardDef.title}
|
||||
preview={getCardPreview(cardId)}
|
||||
status="upcoming"
|
||||
/>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
|
||||
interface CollapsedCardProps {
|
||||
icon: string
|
||||
title: string
|
||||
summary?: string
|
||||
// For upcoming cards - can be a simple string or rich preview with multiple lines
|
||||
preview?: {
|
||||
primary: string
|
||||
secondary?: string
|
||||
tertiary?: string
|
||||
} | string
|
||||
status: 'done' | 'upcoming'
|
||||
}
|
||||
|
||||
export function CollapsedCard({ icon, title, summary, preview, status }: CollapsedCardProps) {
|
||||
const isDone = status === 'done'
|
||||
|
||||
// Parse preview into lines
|
||||
const previewObj = typeof preview === 'string'
|
||||
? { primary: preview }
|
||||
: preview
|
||||
|
||||
// Rich preview means we need a larger card
|
||||
const hasRichPreview = !isDone && previewObj && (previewObj.secondary || previewObj.tertiary)
|
||||
|
||||
return (
|
||||
<div
|
||||
data-element="collapsed-card"
|
||||
data-status={status}
|
||||
title={title}
|
||||
className={css({
|
||||
width: hasRichPreview ? '90px' : '70px',
|
||||
minHeight: hasRichPreview ? '85px' : '70px',
|
||||
py: hasRichPreview ? 2 : 0,
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: hasRichPreview ? 1 : 0.5,
|
||||
bg: 'gray.800',
|
||||
borderRadius: 'lg',
|
||||
border: '2px solid',
|
||||
borderColor: isDone ? 'green.700' : 'gray.700',
|
||||
opacity: isDone ? 1 : 0.7,
|
||||
transition: 'all 0.3s ease',
|
||||
})}
|
||||
>
|
||||
{/* Icon */}
|
||||
<span className={css({ fontSize: 'lg' })}>{icon}</span>
|
||||
|
||||
{isDone ? (
|
||||
/* Done state - simple summary */
|
||||
<span
|
||||
className={css({
|
||||
fontSize: 'xs',
|
||||
color: 'green.400',
|
||||
fontWeight: 'medium',
|
||||
textAlign: 'center',
|
||||
px: 1,
|
||||
whiteSpace: 'nowrap',
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis',
|
||||
maxWidth: '100%',
|
||||
})}
|
||||
>
|
||||
{summary || '✓'}
|
||||
</span>
|
||||
) : previewObj ? (
|
||||
/* Upcoming state with preview */
|
||||
<div className={css({ display: 'flex', flexDirection: 'column', alignItems: 'center', gap: 0 })}>
|
||||
<span
|
||||
className={css({
|
||||
fontSize: 'xs',
|
||||
color: 'gray.300',
|
||||
fontWeight: 'medium',
|
||||
textAlign: 'center',
|
||||
whiteSpace: 'nowrap',
|
||||
})}
|
||||
>
|
||||
{previewObj.primary}
|
||||
</span>
|
||||
{previewObj.secondary && (
|
||||
<span
|
||||
className={css({
|
||||
fontSize: '10px',
|
||||
color: 'gray.500',
|
||||
textAlign: 'center',
|
||||
whiteSpace: 'nowrap',
|
||||
})}
|
||||
>
|
||||
{previewObj.secondary}
|
||||
</span>
|
||||
)}
|
||||
{previewObj.tertiary && (
|
||||
<span
|
||||
className={css({
|
||||
fontSize: '10px',
|
||||
color: 'gray.500',
|
||||
textAlign: 'center',
|
||||
whiteSpace: 'nowrap',
|
||||
})}
|
||||
>
|
||||
{previewObj.tertiary}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
/* Fallback to title */
|
||||
<span
|
||||
className={css({
|
||||
fontSize: 'xs',
|
||||
color: 'gray.400',
|
||||
fontWeight: 'medium',
|
||||
textAlign: 'center',
|
||||
})}
|
||||
>
|
||||
{title}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,177 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
import { DataCard } from './cards/DataCard'
|
||||
import { HardwareCard } from './cards/HardwareCard'
|
||||
import { ConfigCard } from './cards/ConfigCard'
|
||||
import { SetupCard } from './cards/SetupCard'
|
||||
import { LoadingCard } from './cards/LoadingCard'
|
||||
import { TrainingCard } from './cards/TrainingCard'
|
||||
import { ExportCard } from './cards/ExportCard'
|
||||
import { ResultsCard } from './cards/ResultsCard'
|
||||
import {
|
||||
CARDS,
|
||||
type CardId,
|
||||
type SamplesData,
|
||||
type HardwareInfo,
|
||||
type TrainingConfig,
|
||||
type ServerPhase,
|
||||
type EpochData,
|
||||
type DatasetInfo,
|
||||
type TrainingResult,
|
||||
} from './types'
|
||||
|
||||
interface ExpandedCardProps {
|
||||
cardId: CardId
|
||||
// Data
|
||||
samples: SamplesData | null
|
||||
samplesLoading: boolean
|
||||
hardwareInfo: HardwareInfo | null
|
||||
hardwareLoading: boolean
|
||||
fetchHardware: () => void
|
||||
config: TrainingConfig
|
||||
setConfig: (config: TrainingConfig | ((prev: TrainingConfig) => TrainingConfig)) => void
|
||||
isGpu: boolean
|
||||
// Training
|
||||
serverPhase: ServerPhase
|
||||
statusMessage: string
|
||||
currentEpoch: EpochData | null
|
||||
bestAccuracy: number
|
||||
datasetInfo: DatasetInfo | null
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
// Actions
|
||||
onProgress: () => void
|
||||
onStartTraining: () => void
|
||||
onCancel: () => void
|
||||
onTrainAgain: () => void
|
||||
onSyncComplete?: () => void
|
||||
canStartTraining: boolean
|
||||
}
|
||||
|
||||
export function ExpandedCard({
|
||||
cardId,
|
||||
samples,
|
||||
samplesLoading,
|
||||
hardwareInfo,
|
||||
hardwareLoading,
|
||||
fetchHardware,
|
||||
config,
|
||||
setConfig,
|
||||
isGpu,
|
||||
serverPhase,
|
||||
statusMessage,
|
||||
currentEpoch,
|
||||
bestAccuracy,
|
||||
datasetInfo,
|
||||
result,
|
||||
error,
|
||||
onProgress,
|
||||
onStartTraining,
|
||||
onCancel,
|
||||
onTrainAgain,
|
||||
onSyncComplete,
|
||||
canStartTraining,
|
||||
}: ExpandedCardProps) {
|
||||
const cardDef = CARDS[cardId]
|
||||
|
||||
const renderCardContent = () => {
|
||||
switch (cardId) {
|
||||
case 'data':
|
||||
return (
|
||||
<DataCard
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
onProgress={onProgress}
|
||||
onSyncComplete={onSyncComplete}
|
||||
/>
|
||||
)
|
||||
case 'hardware':
|
||||
return (
|
||||
<HardwareCard
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
onProgress={onProgress}
|
||||
/>
|
||||
)
|
||||
case 'config':
|
||||
return (
|
||||
<ConfigCard
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
isGpu={isGpu}
|
||||
onStartTraining={onStartTraining}
|
||||
canStart={canStartTraining}
|
||||
/>
|
||||
)
|
||||
case 'setup':
|
||||
return <SetupCard message={statusMessage} />
|
||||
case 'loading':
|
||||
return <LoadingCard datasetInfo={datasetInfo} message={statusMessage} />
|
||||
case 'training':
|
||||
return (
|
||||
<TrainingCard
|
||||
currentEpoch={currentEpoch}
|
||||
totalEpochs={config.epochs}
|
||||
bestAccuracy={bestAccuracy}
|
||||
onCancel={onCancel}
|
||||
/>
|
||||
)
|
||||
case 'export':
|
||||
return <ExportCard message={statusMessage} />
|
||||
case 'results':
|
||||
return (
|
||||
<ResultsCard
|
||||
result={result}
|
||||
error={error}
|
||||
onTrainAgain={onTrainAgain}
|
||||
/>
|
||||
)
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
data-element="expanded-card"
|
||||
data-card={cardId}
|
||||
className={css({
|
||||
flex: '1 1 300px',
|
||||
maxWidth: '400px',
|
||||
minHeight: '200px',
|
||||
bg: 'gray.800',
|
||||
borderRadius: 'xl',
|
||||
border: '2px solid',
|
||||
borderColor: 'blue.500',
|
||||
boxShadow: 'lg',
|
||||
overflow: 'hidden',
|
||||
transition: 'all 0.3s ease',
|
||||
})}
|
||||
>
|
||||
{/* Card Header */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
borderBottom: '1px solid',
|
||||
borderColor: 'gray.700',
|
||||
bg: 'gray.850',
|
||||
})}
|
||||
>
|
||||
<span className={css({ fontSize: 'lg' })}>{cardDef.icon}</span>
|
||||
<span className={css({ fontWeight: 'semibold', color: 'gray.100' })}>
|
||||
{cardDef.title}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Card Content */}
|
||||
<div className={css({ p: 4 })}>
|
||||
{renderCardContent()}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,238 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
import { CardCarousel } from './CardCarousel'
|
||||
import { StepProgress } from './StepProgress'
|
||||
import {
|
||||
CARDS,
|
||||
type PhaseDefinition,
|
||||
type PhaseStatus,
|
||||
type SamplesData,
|
||||
type HardwareInfo,
|
||||
type TrainingConfig,
|
||||
type ServerPhase,
|
||||
type EpochData,
|
||||
type DatasetInfo,
|
||||
type TrainingResult,
|
||||
} from './types'
|
||||
|
||||
interface PhaseSectionProps {
|
||||
phase: PhaseDefinition
|
||||
status: PhaseStatus
|
||||
currentCardIndex: number
|
||||
// Data
|
||||
samples: SamplesData | null
|
||||
samplesLoading: boolean
|
||||
hardwareInfo: HardwareInfo | null
|
||||
hardwareLoading: boolean
|
||||
fetchHardware: () => void
|
||||
config: TrainingConfig
|
||||
setConfig: (config: TrainingConfig | ((prev: TrainingConfig) => TrainingConfig)) => void
|
||||
isGpu: boolean
|
||||
// Training
|
||||
serverPhase: ServerPhase
|
||||
statusMessage: string
|
||||
currentEpoch: EpochData | null
|
||||
bestAccuracy: number
|
||||
datasetInfo: DatasetInfo | null
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
// Summaries
|
||||
getCardSummary: (cardId: string) => { label: string; value: string } | null
|
||||
// Actions
|
||||
onProgress: () => void
|
||||
onStartTraining: () => void
|
||||
onCancel: () => void
|
||||
onTrainAgain: () => void
|
||||
onSyncComplete?: () => void
|
||||
canStartTraining: boolean
|
||||
}
|
||||
|
||||
const STATUS_STYLES: Record<PhaseStatus, { borderColor: string; bg: string; opacity: number }> = {
|
||||
done: { borderColor: 'green.600', bg: 'gray.850', opacity: 1 },
|
||||
current: { borderColor: 'blue.500', bg: 'gray.800', opacity: 1 },
|
||||
upcoming: { borderColor: 'gray.700', bg: 'gray.850', opacity: 0.6 },
|
||||
}
|
||||
|
||||
export function PhaseSection({
|
||||
phase,
|
||||
status,
|
||||
currentCardIndex,
|
||||
samples,
|
||||
samplesLoading,
|
||||
hardwareInfo,
|
||||
hardwareLoading,
|
||||
fetchHardware,
|
||||
config,
|
||||
setConfig,
|
||||
isGpu,
|
||||
serverPhase,
|
||||
statusMessage,
|
||||
currentEpoch,
|
||||
bestAccuracy,
|
||||
datasetInfo,
|
||||
result,
|
||||
error,
|
||||
getCardSummary,
|
||||
onProgress,
|
||||
onStartTraining,
|
||||
onCancel,
|
||||
onTrainAgain,
|
||||
onSyncComplete,
|
||||
canStartTraining,
|
||||
}: PhaseSectionProps) {
|
||||
const styles = STATUS_STYLES[status]
|
||||
|
||||
return (
|
||||
<div
|
||||
data-element={`phase-${phase.id}`}
|
||||
data-status={status}
|
||||
className={css({
|
||||
borderLeft: '3px solid',
|
||||
borderColor: styles.borderColor,
|
||||
bg: styles.bg,
|
||||
borderRadius: 'lg',
|
||||
overflow: 'hidden',
|
||||
opacity: styles.opacity,
|
||||
transition: 'all 0.3s ease',
|
||||
})}
|
||||
>
|
||||
{/* Phase Header */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
p: 3,
|
||||
borderBottom: status === 'current' ? '1px solid' : 'none',
|
||||
borderColor: 'gray.700',
|
||||
})}
|
||||
>
|
||||
<div className={css({ display: 'flex', alignItems: 'center', gap: 2 })}>
|
||||
{/* Status indicator */}
|
||||
<div
|
||||
className={css({
|
||||
width: '20px',
|
||||
height: '20px',
|
||||
borderRadius: 'full',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
fontSize: 'xs',
|
||||
fontWeight: 'bold',
|
||||
bg: status === 'done' ? 'green.600' : status === 'current' ? 'blue.600' : 'gray.600',
|
||||
color: 'white',
|
||||
})}
|
||||
>
|
||||
{status === 'done' ? '✓' : status === 'current' ? '●' : '○'}
|
||||
</div>
|
||||
<span
|
||||
className={css({
|
||||
fontWeight: 'semibold',
|
||||
fontSize: 'sm',
|
||||
color: status === 'current' ? 'gray.100' : 'gray.400',
|
||||
})}
|
||||
>
|
||||
{phase.title}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Status badge */}
|
||||
<span
|
||||
className={css({
|
||||
fontSize: 'xs',
|
||||
px: 2,
|
||||
py: 0.5,
|
||||
borderRadius: 'full',
|
||||
bg: status === 'done' ? 'green.800' : status === 'current' ? 'blue.800' : 'gray.700',
|
||||
color: status === 'done' ? 'green.300' : status === 'current' ? 'blue.300' : 'gray.500',
|
||||
})}
|
||||
>
|
||||
{status === 'done' ? 'Complete' : status === 'current' ? 'In Progress' : 'Upcoming'}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Phase Content */}
|
||||
{status === 'current' ? (
|
||||
<div className={css({ p: 4 })}>
|
||||
{/* Card Carousel */}
|
||||
<CardCarousel
|
||||
cards={phase.cards}
|
||||
currentCardIndex={currentCardIndex}
|
||||
// Data
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
isGpu={isGpu}
|
||||
// Training
|
||||
serverPhase={serverPhase}
|
||||
statusMessage={statusMessage}
|
||||
currentEpoch={currentEpoch}
|
||||
bestAccuracy={bestAccuracy}
|
||||
datasetInfo={datasetInfo}
|
||||
result={result}
|
||||
error={error}
|
||||
// Summaries
|
||||
getCardSummary={getCardSummary}
|
||||
// Actions
|
||||
onProgress={onProgress}
|
||||
onStartTraining={onStartTraining}
|
||||
onCancel={onCancel}
|
||||
onTrainAgain={onTrainAgain}
|
||||
onSyncComplete={onSyncComplete}
|
||||
canStartTraining={canStartTraining}
|
||||
/>
|
||||
|
||||
{/* Step Progress - only for multi-card phases */}
|
||||
{phase.cards.length > 1 && (
|
||||
<StepProgress
|
||||
steps={phase.cards.map((cardId) => CARDS[cardId].title)}
|
||||
currentIndex={currentCardIndex}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
/* Collapsed summary for done/upcoming */
|
||||
<div className={css({ px: 4, py: 2 })}>
|
||||
{status === 'done' ? (
|
||||
<div className={css({ display: 'flex', gap: 3, flexWrap: 'wrap' })}>
|
||||
{phase.cards.map((cardId) => {
|
||||
const summary = getCardSummary(cardId)
|
||||
const cardDef = CARDS[cardId]
|
||||
return (
|
||||
<div
|
||||
key={cardId}
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 1.5,
|
||||
px: 2,
|
||||
py: 1,
|
||||
bg: 'gray.800',
|
||||
borderRadius: 'md',
|
||||
fontSize: 'xs',
|
||||
})}
|
||||
>
|
||||
<span>{cardDef.icon}</span>
|
||||
<span className={css({ color: 'gray.400' })}>{cardDef.title}:</span>
|
||||
<span className={css({ color: 'green.400', fontWeight: 'medium' })}>
|
||||
{summary?.value || '✓'}
|
||||
</span>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', fontStyle: 'italic' })}>
|
||||
{phase.id === 'results' ? 'Waiting for training to complete...' : 'Waiting...'}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
|
||||
interface StepProgressProps {
|
||||
steps: string[]
|
||||
currentIndex: number
|
||||
}
|
||||
|
||||
export function StepProgress({ steps, currentIndex }: StepProgressProps) {
|
||||
return (
|
||||
<div
|
||||
data-element="step-progress"
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 1,
|
||||
mt: 2,
|
||||
})}
|
||||
>
|
||||
{steps.map((step, index) => {
|
||||
const isDone = index < currentIndex
|
||||
const isCurrent = index === currentIndex
|
||||
const isUpcoming = index > currentIndex
|
||||
|
||||
return (
|
||||
<div key={step} className={css({ display: 'flex', alignItems: 'center' })}>
|
||||
{/* Dot */}
|
||||
<div
|
||||
title={step}
|
||||
className={css({
|
||||
width: isCurrent ? '10px' : '8px',
|
||||
height: isCurrent ? '10px' : '8px',
|
||||
borderRadius: 'full',
|
||||
bg: isDone ? 'green.500' : isCurrent ? 'blue.500' : 'gray.600',
|
||||
transition: 'all 0.2s ease',
|
||||
cursor: 'help',
|
||||
})}
|
||||
/>
|
||||
|
||||
{/* Connector line (if not last) */}
|
||||
{index < steps.length - 1 && (
|
||||
<div
|
||||
className={css({
|
||||
width: '20px',
|
||||
height: '2px',
|
||||
bg: isDone ? 'green.500' : 'gray.700',
|
||||
mx: 0.5,
|
||||
})}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,195 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { css } from '../../../../../../styled-system/css'
|
||||
import { PhaseSection } from './PhaseSection'
|
||||
import {
|
||||
PHASES,
|
||||
serverPhaseToWizardPosition,
|
||||
type SamplesData,
|
||||
type HardwareInfo,
|
||||
type TrainingConfig,
|
||||
type ServerPhase,
|
||||
type EpochData,
|
||||
type DatasetInfo,
|
||||
type TrainingResult,
|
||||
type PhaseStatus,
|
||||
} from './types'
|
||||
|
||||
interface TrainingWizardProps {
|
||||
// Data state
|
||||
samples: SamplesData | null
|
||||
samplesLoading: boolean
|
||||
// Hardware state
|
||||
hardwareInfo: HardwareInfo | null
|
||||
hardwareLoading: boolean
|
||||
fetchHardware: () => void
|
||||
// Config state
|
||||
config: TrainingConfig
|
||||
setConfig: (config: TrainingConfig | ((prev: TrainingConfig) => TrainingConfig)) => void
|
||||
// Training state (from server)
|
||||
serverPhase: ServerPhase
|
||||
statusMessage: string
|
||||
currentEpoch: EpochData | null
|
||||
epochHistory: EpochData[]
|
||||
datasetInfo: DatasetInfo | null
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
// Actions
|
||||
onStart: () => void
|
||||
onCancel: () => void
|
||||
onReset: () => void
|
||||
onSyncComplete?: () => void
|
||||
}
|
||||
|
||||
export function TrainingWizard({
|
||||
samples,
|
||||
samplesLoading,
|
||||
hardwareInfo,
|
||||
hardwareLoading,
|
||||
fetchHardware,
|
||||
config,
|
||||
setConfig,
|
||||
serverPhase,
|
||||
statusMessage,
|
||||
currentEpoch,
|
||||
epochHistory,
|
||||
datasetInfo,
|
||||
result,
|
||||
error,
|
||||
onStart,
|
||||
onCancel,
|
||||
onReset,
|
||||
onSyncComplete,
|
||||
}: TrainingWizardProps) {
|
||||
// Wizard position state
|
||||
const [currentPhaseIndex, setCurrentPhaseIndex] = useState(0)
|
||||
const [currentCardIndex, setCurrentCardIndex] = useState(0)
|
||||
|
||||
// Derive state
|
||||
const isGpu = hardwareInfo?.deviceType === 'gpu'
|
||||
const bestAccuracy = epochHistory.length > 0 ? Math.max(...epochHistory.map((e) => e.val_accuracy)) : 0
|
||||
const hasEnoughData = samples?.hasData && samples.dataQuality !== 'none' && samples.dataQuality !== 'insufficient'
|
||||
|
||||
// Sync wizard position with server phase during training
|
||||
useEffect(() => {
|
||||
if (serverPhase !== 'idle') {
|
||||
const { phaseIndex, cardIndex } = serverPhaseToWizardPosition(serverPhase)
|
||||
setCurrentPhaseIndex(phaseIndex)
|
||||
setCurrentCardIndex(cardIndex)
|
||||
}
|
||||
}, [serverPhase])
|
||||
|
||||
// Progress to next card
|
||||
const progressToNextCard = useCallback(() => {
|
||||
const currentPhase = PHASES[currentPhaseIndex]
|
||||
|
||||
if (currentCardIndex < currentPhase.cards.length - 1) {
|
||||
// More cards in this phase
|
||||
setCurrentCardIndex((prev) => prev + 1)
|
||||
} else if (currentPhaseIndex < PHASES.length - 1) {
|
||||
// Move to next phase
|
||||
setCurrentPhaseIndex((prev) => prev + 1)
|
||||
setCurrentCardIndex(0)
|
||||
}
|
||||
}, [currentPhaseIndex, currentCardIndex])
|
||||
|
||||
// Handle starting training (transition from config to training phase)
|
||||
const handleStartTraining = useCallback(() => {
|
||||
// Move to training phase
|
||||
setCurrentPhaseIndex(1)
|
||||
setCurrentCardIndex(0)
|
||||
// Trigger actual training
|
||||
onStart()
|
||||
}, [onStart])
|
||||
|
||||
// Handle train again (reset to preparation)
|
||||
const handleTrainAgain = useCallback(() => {
|
||||
setCurrentPhaseIndex(0)
|
||||
setCurrentCardIndex(0)
|
||||
onReset()
|
||||
}, [onReset])
|
||||
|
||||
// Get phase status based on current position
|
||||
const getPhaseStatus = (phaseIndex: number): PhaseStatus => {
|
||||
if (phaseIndex < currentPhaseIndex) return 'done'
|
||||
if (phaseIndex === currentPhaseIndex) return 'current'
|
||||
return 'upcoming'
|
||||
}
|
||||
|
||||
// Card summaries for done states
|
||||
const getCardSummary = (cardId: string): { label: string; value: string } | null => {
|
||||
switch (cardId) {
|
||||
case 'data':
|
||||
if (!samples?.hasData) return null
|
||||
return { label: 'Images', value: `${samples.totalImages}` }
|
||||
case 'hardware':
|
||||
if (!hardwareInfo) return null
|
||||
return {
|
||||
label: hardwareInfo.deviceType === 'gpu' ? 'GPU' : 'CPU',
|
||||
value: hardwareInfo.deviceName.split(' ').slice(0, 2).join(' ')
|
||||
}
|
||||
case 'config':
|
||||
return { label: 'Epochs', value: `${config.epochs}` }
|
||||
case 'setup':
|
||||
return { label: 'Ready', value: '✓' }
|
||||
case 'loading':
|
||||
return { label: 'Loaded', value: datasetInfo ? `${datasetInfo.total_images}` : '✓' }
|
||||
case 'training':
|
||||
return { label: 'Accuracy', value: `${(bestAccuracy * 100).toFixed(0)}%` }
|
||||
case 'export':
|
||||
return { label: 'Exported', value: '✓' }
|
||||
case 'results':
|
||||
return result ? { label: 'Final', value: `${(result.final_accuracy * 100).toFixed(1)}%` } : null
|
||||
default:
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
data-component="training-wizard"
|
||||
className={css({
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 4,
|
||||
})}
|
||||
>
|
||||
{PHASES.map((phase, phaseIndex) => (
|
||||
<PhaseSection
|
||||
key={phase.id}
|
||||
phase={phase}
|
||||
status={getPhaseStatus(phaseIndex)}
|
||||
currentCardIndex={phaseIndex === currentPhaseIndex ? currentCardIndex : -1}
|
||||
// Data for cards
|
||||
samples={samples}
|
||||
samplesLoading={samplesLoading}
|
||||
hardwareInfo={hardwareInfo}
|
||||
hardwareLoading={hardwareLoading}
|
||||
fetchHardware={fetchHardware}
|
||||
config={config}
|
||||
setConfig={setConfig}
|
||||
isGpu={isGpu}
|
||||
// Training data
|
||||
serverPhase={serverPhase}
|
||||
statusMessage={statusMessage}
|
||||
currentEpoch={currentEpoch}
|
||||
bestAccuracy={bestAccuracy}
|
||||
datasetInfo={datasetInfo}
|
||||
result={result}
|
||||
error={error}
|
||||
// Summaries
|
||||
getCardSummary={getCardSummary}
|
||||
// Actions
|
||||
onProgress={progressToNextCard}
|
||||
onStartTraining={handleStartTraining}
|
||||
onCancel={onCancel}
|
||||
onTrainAgain={handleTrainAgain}
|
||||
onSyncComplete={onSyncComplete}
|
||||
// Validation
|
||||
canStartTraining={!!hasEnoughData && !hardwareLoading && !hardwareInfo?.error}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { TrainingConfig } from '../types'
|
||||
|
||||
interface ConfigCardProps {
|
||||
config: TrainingConfig
|
||||
setConfig: (config: TrainingConfig | ((prev: TrainingConfig) => TrainingConfig)) => void
|
||||
isGpu: boolean
|
||||
onStartTraining: () => void
|
||||
canStart: boolean
|
||||
}
|
||||
|
||||
interface Preset {
|
||||
epochs: number
|
||||
batchSize: number
|
||||
label: string
|
||||
desc: string
|
||||
}
|
||||
|
||||
export function ConfigCard({
|
||||
config,
|
||||
setConfig,
|
||||
isGpu,
|
||||
onStartTraining,
|
||||
canStart,
|
||||
}: ConfigCardProps) {
|
||||
// Hardware-aware presets
|
||||
const presets: Record<string, Preset> = isGpu
|
||||
? {
|
||||
quick: { epochs: 10, batchSize: 32, label: '⚡ Quick', desc: '~2 min' },
|
||||
balanced: { epochs: 50, batchSize: 64, label: '⚖️ Balanced', desc: '~10 min' },
|
||||
best: { epochs: 100, batchSize: 64, label: '✨ Best', desc: '~20 min' },
|
||||
}
|
||||
: {
|
||||
quick: { epochs: 5, batchSize: 16, label: '⚡ Quick', desc: '~5 min' },
|
||||
balanced: { epochs: 25, batchSize: 32, label: '⚖️ Balanced', desc: '~15 min' },
|
||||
best: { epochs: 50, batchSize: 32, label: '✨ Best', desc: '~30 min' },
|
||||
}
|
||||
|
||||
const applyPreset = (preset: Preset) => {
|
||||
setConfig((prev) => ({
|
||||
...prev,
|
||||
epochs: preset.epochs,
|
||||
batchSize: preset.batchSize,
|
||||
}))
|
||||
}
|
||||
|
||||
const isPresetActive = (preset: Preset) =>
|
||||
config.epochs === preset.epochs && config.batchSize === preset.batchSize
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Presets */}
|
||||
<div className={css({ mb: 4 })}>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 2 })}>Presets</div>
|
||||
<div className={css({ display: 'flex', gap: 2 })}>
|
||||
{Object.entries(presets).map(([key, preset]) => (
|
||||
<button
|
||||
key={key}
|
||||
type="button"
|
||||
onClick={() => applyPreset(preset)}
|
||||
className={css({
|
||||
flex: 1,
|
||||
py: 2,
|
||||
px: 2,
|
||||
borderRadius: 'lg',
|
||||
border: '2px solid',
|
||||
borderColor: isPresetActive(preset) ? 'blue.500' : 'gray.700',
|
||||
bg: isPresetActive(preset) ? 'blue.900' : 'gray.800',
|
||||
color: isPresetActive(preset) ? 'blue.300' : 'gray.300',
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s',
|
||||
_hover: { borderColor: 'blue.400' },
|
||||
})}
|
||||
>
|
||||
<div className={css({ fontSize: 'sm', fontWeight: 'medium' })}>{preset.label}</div>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500' })}>{preset.desc}</div>
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Epochs slider */}
|
||||
<div className={css({ mb: 4 })}>
|
||||
<div className={css({ display: 'flex', justifyContent: 'space-between', mb: 1 })}>
|
||||
<span className={css({ fontSize: 'xs', color: 'gray.500' })}>Training Rounds</span>
|
||||
<span className={css({ fontSize: 'sm', fontWeight: 'medium', color: 'gray.200' })}>
|
||||
{config.epochs}
|
||||
</span>
|
||||
</div>
|
||||
<input
|
||||
type="range"
|
||||
min={5}
|
||||
max={isGpu ? 150 : 75}
|
||||
value={config.epochs}
|
||||
onChange={(e) => setConfig((prev) => ({ ...prev, epochs: parseInt(e.target.value, 10) }))}
|
||||
className={css({
|
||||
width: '100%',
|
||||
height: '6px',
|
||||
borderRadius: 'full',
|
||||
bg: 'gray.700',
|
||||
appearance: 'none',
|
||||
cursor: 'pointer',
|
||||
'&::-webkit-slider-thumb': {
|
||||
appearance: 'none',
|
||||
width: '16px',
|
||||
height: '16px',
|
||||
borderRadius: 'full',
|
||||
bg: 'blue.500',
|
||||
cursor: 'pointer',
|
||||
},
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Batch size */}
|
||||
<div className={css({ mb: 4 })}>
|
||||
<div className={css({ display: 'flex', justifyContent: 'space-between', mb: 1 })}>
|
||||
<span className={css({ fontSize: 'xs', color: 'gray.500' })}>Batch Size</span>
|
||||
<span className={css({ fontSize: 'sm', fontWeight: 'medium', color: 'gray.200' })}>
|
||||
{config.batchSize}
|
||||
</span>
|
||||
</div>
|
||||
<div className={css({ display: 'flex', gap: 2 })}>
|
||||
{(isGpu ? [32, 64, 128] : [16, 32, 64]).map((size) => (
|
||||
<button
|
||||
key={size}
|
||||
type="button"
|
||||
onClick={() => setConfig((prev) => ({ ...prev, batchSize: size }))}
|
||||
className={css({
|
||||
flex: 1,
|
||||
py: 1.5,
|
||||
borderRadius: 'md',
|
||||
border: '1px solid',
|
||||
borderColor: config.batchSize === size ? 'blue.500' : 'gray.700',
|
||||
bg: config.batchSize === size ? 'blue.900' : 'transparent',
|
||||
color: config.batchSize === size ? 'blue.300' : 'gray.400',
|
||||
fontSize: 'sm',
|
||||
cursor: 'pointer',
|
||||
_hover: { borderColor: 'blue.400' },
|
||||
})}
|
||||
>
|
||||
{size}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Augmentation toggle */}
|
||||
<div className={css({ mb: 6 })}>
|
||||
<label
|
||||
className={css({
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 2,
|
||||
cursor: 'pointer',
|
||||
})}
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={config.augmentation}
|
||||
onChange={(e) => setConfig((prev) => ({ ...prev, augmentation: e.target.checked }))}
|
||||
className={css({
|
||||
width: '18px',
|
||||
height: '18px',
|
||||
accentColor: 'rgb(59, 130, 246)',
|
||||
})}
|
||||
/>
|
||||
<span className={css({ fontSize: 'sm', color: 'gray.200' })}>Data Augmentation</span>
|
||||
</label>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', ml: 6, mt: 0.5 })}>
|
||||
{isGpu ? 'Recommended - your GPU handles extra data easily' : 'Adds processing time but improves results'}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Start button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={onStartTraining}
|
||||
disabled={!canStart}
|
||||
className={css({
|
||||
width: '100%',
|
||||
py: 3,
|
||||
bg: canStart ? 'green.600' : 'gray.700',
|
||||
color: canStart ? 'white' : 'gray.500',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: canStart ? 'pointer' : 'not-allowed',
|
||||
fontWeight: 'bold',
|
||||
fontSize: 'md',
|
||||
transition: 'all 0.2s',
|
||||
_hover: canStart ? { bg: 'green.500' } : {},
|
||||
})}
|
||||
>
|
||||
{canStart ? 'Start Training →' : 'Complete previous steps first'}
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
'use client'
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import Link from 'next/link'
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { SamplesData } from '../types'
|
||||
|
||||
interface DataCardProps {
|
||||
samples: SamplesData | null
|
||||
samplesLoading: boolean
|
||||
onProgress: () => void
|
||||
onSyncComplete?: () => void // Callback to refresh samples after sync
|
||||
}
|
||||
|
||||
interface SyncStatus {
|
||||
available: boolean
|
||||
remote?: { host: string; totalImages: number }
|
||||
local?: { totalImages: number }
|
||||
needsSync?: boolean
|
||||
error?: string
|
||||
}
|
||||
|
||||
interface SyncProgress {
|
||||
phase: 'idle' | 'connecting' | 'syncing' | 'complete' | 'error'
|
||||
message: string
|
||||
filesTransferred?: number
|
||||
bytesTransferred?: number
|
||||
}
|
||||
|
||||
const QUALITY_CONFIG: Record<SamplesData['dataQuality'], { color: string; label: string; barWidth: string }> = {
|
||||
none: { color: 'gray.500', label: 'No Data', barWidth: '0%' },
|
||||
insufficient: { color: 'red.400', label: 'Need More', barWidth: '20%' },
|
||||
minimal: { color: 'yellow.400', label: 'Minimal', barWidth: '50%' },
|
||||
good: { color: 'green.400', label: 'Good', barWidth: '80%' },
|
||||
excellent: { color: 'green.300', label: 'Excellent', barWidth: '100%' },
|
||||
}
|
||||
|
||||
export function DataCard({ samples, samplesLoading, onProgress, onSyncComplete }: DataCardProps) {
|
||||
const [syncStatus, setSyncStatus] = useState<SyncStatus | null>(null)
|
||||
const [syncProgress, setSyncProgress] = useState<SyncProgress>({ phase: 'idle', message: '' })
|
||||
const [syncChecking, setSyncChecking] = useState(true)
|
||||
const abortRef = useRef<AbortController | null>(null)
|
||||
|
||||
const isReady = samples?.hasData && samples.dataQuality !== 'none' && samples.dataQuality !== 'insufficient'
|
||||
const isSyncing = syncProgress.phase === 'connecting' || syncProgress.phase === 'syncing'
|
||||
|
||||
// Check sync availability on mount
|
||||
useEffect(() => {
|
||||
const checkSync = async () => {
|
||||
setSyncChecking(true)
|
||||
try {
|
||||
const response = await fetch('/api/vision-training/sync')
|
||||
const data = await response.json()
|
||||
setSyncStatus(data)
|
||||
} catch {
|
||||
setSyncStatus({ available: false, error: 'Failed to check sync status' })
|
||||
} finally {
|
||||
setSyncChecking(false)
|
||||
}
|
||||
}
|
||||
checkSync()
|
||||
}, [])
|
||||
|
||||
// Start sync
|
||||
const startSync = useCallback(async () => {
|
||||
setSyncProgress({ phase: 'connecting', message: 'Connecting to production...' })
|
||||
abortRef.current = new AbortController()
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/vision-training/sync', {
|
||||
method: 'POST',
|
||||
signal: abortRef.current.signal,
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to start sync')
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) throw new Error('No response body')
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
let eventType = ''
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('event: ')) {
|
||||
eventType = line.slice(7)
|
||||
} else if (line.startsWith('data: ')) {
|
||||
const data = JSON.parse(line.slice(6))
|
||||
handleSyncEvent(eventType, data)
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if ((error as Error).name === 'AbortError') {
|
||||
setSyncProgress({ phase: 'idle', message: '' })
|
||||
} else {
|
||||
setSyncProgress({
|
||||
phase: 'error',
|
||||
message: error instanceof Error ? error.message : 'Sync failed',
|
||||
})
|
||||
}
|
||||
}
|
||||
}, [])
|
||||
|
||||
const handleSyncEvent = (eventType: string, data: Record<string, unknown>) => {
|
||||
switch (eventType) {
|
||||
case 'status':
|
||||
setSyncProgress({
|
||||
phase: data.phase as SyncProgress['phase'],
|
||||
message: data.message as string,
|
||||
})
|
||||
break
|
||||
case 'progress':
|
||||
setSyncProgress({
|
||||
phase: 'syncing',
|
||||
message: data.message as string,
|
||||
filesTransferred: data.filesTransferred as number,
|
||||
bytesTransferred: data.bytesTransferred as number,
|
||||
})
|
||||
break
|
||||
case 'complete':
|
||||
setSyncProgress({
|
||||
phase: 'complete',
|
||||
message: `Synced ${data.filesTransferred} files`,
|
||||
filesTransferred: data.filesTransferred as number,
|
||||
})
|
||||
// Update sync status with new local counts
|
||||
setSyncStatus((prev) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
local: { totalImages: data.totalImages as number },
|
||||
needsSync: false,
|
||||
}
|
||||
: null
|
||||
)
|
||||
// Trigger parent to refresh samples
|
||||
onSyncComplete?.()
|
||||
break
|
||||
case 'error':
|
||||
setSyncProgress({
|
||||
phase: 'error',
|
||||
message: data.message as string,
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
const cancelSync = useCallback(() => {
|
||||
abortRef.current?.abort()
|
||||
setSyncProgress({ phase: 'idle', message: '' })
|
||||
}, [])
|
||||
|
||||
if (samplesLoading && syncChecking) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 4 })}>
|
||||
<span className={css({ fontSize: 'lg', animation: 'spin 1s linear infinite' })}>⏳</span>
|
||||
<div className={css({ color: 'gray.400', mt: 2 })}>Loading...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// Show sync UI prominently if sync is available and needed
|
||||
const showSyncUI = syncStatus?.available && (syncStatus.needsSync || !samples?.hasData)
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Sync from Production Section */}
|
||||
{showSyncUI && syncProgress.phase !== 'complete' && (
|
||||
<div
|
||||
className={css({
|
||||
mb: 4,
|
||||
p: 3,
|
||||
bg: 'blue.900/30',
|
||||
border: '1px solid',
|
||||
borderColor: 'blue.700',
|
||||
borderRadius: 'lg',
|
||||
})}
|
||||
>
|
||||
<div className={css({ display: 'flex', alignItems: 'center', gap: 2, mb: 2 })}>
|
||||
<span>☁️</span>
|
||||
<span className={css({ fontWeight: 'medium', color: 'blue.300' })}>
|
||||
Production Data Available
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{syncStatus.remote && (
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.400', mb: 3 })}>
|
||||
<strong className={css({ color: 'blue.400' })}>
|
||||
{syncStatus.remote.totalImages.toLocaleString()}
|
||||
</strong>{' '}
|
||||
images on {syncStatus.remote.host}
|
||||
{syncStatus.local && syncStatus.local.totalImages > 0 && (
|
||||
<span>
|
||||
{' '}
|
||||
({(syncStatus.remote.totalImages - syncStatus.local.totalImages).toLocaleString()} new)
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Sync progress */}
|
||||
{isSyncing && (
|
||||
<div className={css({ mb: 3 })}>
|
||||
<div className={css({ display: 'flex', alignItems: 'center', gap: 2, mb: 2 })}>
|
||||
<span className={css({ animation: 'spin 1s linear infinite' })}>🔄</span>
|
||||
<span className={css({ fontSize: 'sm', color: 'gray.300' })}>{syncProgress.message}</span>
|
||||
</div>
|
||||
{syncProgress.filesTransferred !== undefined && syncProgress.filesTransferred > 0 && (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500' })}>
|
||||
{syncProgress.filesTransferred} files transferred
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error state */}
|
||||
{syncProgress.phase === 'error' && (
|
||||
<div className={css({ color: 'red.400', fontSize: 'sm', mb: 3 })}>{syncProgress.message}</div>
|
||||
)}
|
||||
|
||||
{/* Action buttons */}
|
||||
<div className={css({ display: 'flex', gap: 2 })}>
|
||||
{!isSyncing ? (
|
||||
<>
|
||||
<button
|
||||
type="button"
|
||||
onClick={startSync}
|
||||
className={css({
|
||||
flex: 1,
|
||||
py: 2,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Sync Now
|
||||
</button>
|
||||
{samples?.hasData && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={onProgress}
|
||||
className={css({
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'transparent',
|
||||
color: 'gray.400',
|
||||
borderRadius: 'lg',
|
||||
border: '1px solid',
|
||||
borderColor: 'gray.600',
|
||||
cursor: 'pointer',
|
||||
_hover: { borderColor: 'gray.500', color: 'gray.300' },
|
||||
})}
|
||||
>
|
||||
Skip
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<button
|
||||
type="button"
|
||||
onClick={cancelSync}
|
||||
className={css({
|
||||
flex: 1,
|
||||
py: 2,
|
||||
bg: 'transparent',
|
||||
color: 'gray.400',
|
||||
borderRadius: 'lg',
|
||||
border: '1px solid',
|
||||
borderColor: 'gray.600',
|
||||
cursor: 'pointer',
|
||||
_hover: { borderColor: 'gray.500', color: 'gray.300' },
|
||||
})}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Sync complete message */}
|
||||
{syncProgress.phase === 'complete' && (
|
||||
<div
|
||||
className={css({
|
||||
mb: 4,
|
||||
p: 3,
|
||||
bg: 'green.900/30',
|
||||
border: '1px solid',
|
||||
borderColor: 'green.700',
|
||||
borderRadius: 'lg',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 2,
|
||||
})}
|
||||
>
|
||||
<span>✅</span>
|
||||
<span className={css({ color: 'green.400', fontSize: 'sm' })}>{syncProgress.message}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* No data state - only if sync not available */}
|
||||
{!samples?.hasData && !showSyncUI && (
|
||||
<div className={css({ textAlign: 'center', py: 4 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 2 })}>📷</div>
|
||||
<div className={css({ color: 'gray.300', mb: 2 })}>No training data collected yet</div>
|
||||
<Link
|
||||
href="/vision-training"
|
||||
className={css({
|
||||
display: 'inline-block',
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
textDecoration: 'none',
|
||||
fontWeight: 'medium',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Collect Training Data
|
||||
</Link>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Show local data stats */}
|
||||
{samples?.hasData && (
|
||||
<>
|
||||
{/* Image count */}
|
||||
<div className={css({ fontSize: 'xl', fontWeight: 'bold', color: 'gray.100', mb: 3 })}>
|
||||
{samples.totalImages.toLocaleString()} images
|
||||
</div>
|
||||
|
||||
{/* Quality indicator */}
|
||||
<div className={css({ mb: 4 })}>
|
||||
<div className={css({ display: 'flex', justifyContent: 'space-between', mb: 1 })}>
|
||||
<span className={css({ fontSize: 'sm', color: 'gray.400' })}>Quality</span>
|
||||
<span
|
||||
className={css({
|
||||
fontSize: 'sm',
|
||||
fontWeight: 'medium',
|
||||
color: QUALITY_CONFIG[samples.dataQuality].color,
|
||||
})}
|
||||
>
|
||||
{QUALITY_CONFIG[samples.dataQuality].label}
|
||||
</span>
|
||||
</div>
|
||||
<div className={css({ height: '6px', bg: 'gray.700', borderRadius: 'full', overflow: 'hidden' })}>
|
||||
<div
|
||||
className={css({
|
||||
height: '100%',
|
||||
bg: QUALITY_CONFIG[samples.dataQuality].color,
|
||||
borderRadius: 'full',
|
||||
transition: 'width 0.3s ease',
|
||||
})}
|
||||
style={{ width: QUALITY_CONFIG[samples.dataQuality].barWidth }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Digit distribution mini-chart */}
|
||||
<div className={css({ mb: 4 })}>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 2 })}>Distribution</div>
|
||||
<div className={css({ display: 'flex', gap: 1, justifyContent: 'space-between' })}>
|
||||
{Object.entries(samples.digits).map(([digit, data]) => {
|
||||
const maxCount = Math.max(...Object.values(samples.digits).map((d) => d.count))
|
||||
const barHeight = maxCount > 0 ? (data.count / maxCount) * 30 : 0
|
||||
return (
|
||||
<div
|
||||
key={digit}
|
||||
className={css({ display: 'flex', flexDirection: 'column', alignItems: 'center', flex: 1 })}
|
||||
>
|
||||
<div
|
||||
className={css({
|
||||
width: '100%',
|
||||
bg: 'blue.600',
|
||||
borderRadius: 'sm',
|
||||
transition: 'height 0.3s ease',
|
||||
})}
|
||||
style={{ height: `${barHeight}px` }}
|
||||
/>
|
||||
<span className={css({ fontSize: 'xs', color: 'gray.500', mt: 1 })}>{digit}</span>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Ready indicator and continue */}
|
||||
{isReady && !isSyncing && (
|
||||
<div className={css({ mt: 4 })}>
|
||||
<div className={css({ display: 'flex', alignItems: 'center', gap: 2, mb: 3 })}>
|
||||
<span className={css({ color: 'green.400' })}>✓</span>
|
||||
<span className={css({ color: 'green.400', fontSize: 'sm' })}>Ready to train</span>
|
||||
</div>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
onClick={onProgress}
|
||||
className={css({
|
||||
width: '100%',
|
||||
py: 2,
|
||||
bg: 'green.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
_hover: { bg: 'green.500' },
|
||||
})}
|
||||
>
|
||||
Continue →
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
|
||||
interface ExportCardProps {
|
||||
message: string
|
||||
}
|
||||
|
||||
export function ExportCard({ message }: ExportCardProps) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 3, animation: 'spin 1s linear infinite' })}>📦</div>
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'medium', color: 'gray.200', mb: 2 })}>
|
||||
Exporting Model
|
||||
</div>
|
||||
<div className={css({ color: 'gray.400', fontSize: 'sm' })}>
|
||||
{message || 'Converting to TensorFlow.js format...'}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { HardwareInfo } from '../types'
|
||||
|
||||
interface HardwareCardProps {
|
||||
hardwareInfo: HardwareInfo | null
|
||||
hardwareLoading: boolean
|
||||
fetchHardware: () => void
|
||||
onProgress: () => void
|
||||
}
|
||||
|
||||
const AUTO_PROGRESS_DELAY = 2000
|
||||
|
||||
export function HardwareCard({
|
||||
hardwareInfo,
|
||||
hardwareLoading,
|
||||
fetchHardware,
|
||||
onProgress,
|
||||
}: HardwareCardProps) {
|
||||
const [countdown, setCountdown] = useState(AUTO_PROGRESS_DELAY)
|
||||
const timerRef = useRef<NodeJS.Timeout | null>(null)
|
||||
const isReady = hardwareInfo && !hardwareInfo.error && !hardwareLoading
|
||||
|
||||
// Auto-progress countdown
|
||||
useEffect(() => {
|
||||
if (!isReady) {
|
||||
setCountdown(AUTO_PROGRESS_DELAY)
|
||||
return
|
||||
}
|
||||
|
||||
// Start countdown
|
||||
const startTime = Date.now()
|
||||
const tick = () => {
|
||||
const elapsed = Date.now() - startTime
|
||||
const remaining = Math.max(0, AUTO_PROGRESS_DELAY - elapsed)
|
||||
setCountdown(remaining)
|
||||
|
||||
if (remaining <= 0) {
|
||||
onProgress()
|
||||
} else {
|
||||
timerRef.current = setTimeout(tick, 50)
|
||||
}
|
||||
}
|
||||
|
||||
timerRef.current = setTimeout(tick, 50)
|
||||
|
||||
return () => {
|
||||
if (timerRef.current) {
|
||||
clearTimeout(timerRef.current)
|
||||
}
|
||||
}
|
||||
}, [isReady, onProgress])
|
||||
|
||||
const handleSkip = () => {
|
||||
if (timerRef.current) {
|
||||
clearTimeout(timerRef.current)
|
||||
}
|
||||
onProgress()
|
||||
}
|
||||
|
||||
if (hardwareLoading) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 2, animation: 'spin 1s linear infinite' })}>⚙️</div>
|
||||
<div className={css({ color: 'gray.400' })}>Detecting hardware...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (hardwareInfo?.error) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 4 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 2 })}>⚠️</div>
|
||||
<div className={css({ color: 'red.400', mb: 2 })}>Hardware setup failed</div>
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>
|
||||
{hardwareInfo.error}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={fetchHardware}
|
||||
className={css({
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Retry Detection
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!hardwareInfo) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ color: 'gray.500' })}>No hardware detected</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const isGpu = hardwareInfo.deviceType === 'gpu'
|
||||
const progressPercent = ((AUTO_PROGRESS_DELAY - countdown) / AUTO_PROGRESS_DELAY) * 100
|
||||
|
||||
return (
|
||||
<div className={css({ textAlign: 'center' })}>
|
||||
{/* Device Icon */}
|
||||
<div className={css({ fontSize: '3xl', mb: 2 })}>
|
||||
{isGpu ? '⚡' : '💻'}
|
||||
</div>
|
||||
|
||||
{/* Device Name */}
|
||||
<div className={css({ fontSize: 'xl', fontWeight: 'bold', color: 'gray.100', mb: 1 })}>
|
||||
{hardwareInfo.deviceName}
|
||||
</div>
|
||||
|
||||
{/* Device Type Badge */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'inline-block',
|
||||
px: 3,
|
||||
py: 1,
|
||||
borderRadius: 'full',
|
||||
fontSize: 'sm',
|
||||
fontWeight: 'bold',
|
||||
bg: isGpu ? 'green.700' : 'blue.700',
|
||||
color: 'white',
|
||||
mb: 3,
|
||||
})}
|
||||
>
|
||||
{hardwareInfo.deviceType.toUpperCase()}
|
||||
{isGpu && ' Acceleration'}
|
||||
</div>
|
||||
|
||||
{/* Hint */}
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.400', mb: 4 })}>
|
||||
{isGpu ? 'Training will be fast!' : 'CPU training available'}
|
||||
</div>
|
||||
|
||||
{/* TensorFlow version */}
|
||||
{typeof hardwareInfo.details?.tensorflowVersion === 'string' && (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
|
||||
TensorFlow {hardwareInfo.details.tensorflowVersion}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Auto-progress bar */}
|
||||
<div className={css({ mb: 3 })}>
|
||||
<div className={css({ height: '4px', bg: 'gray.700', borderRadius: 'full', overflow: 'hidden' })}>
|
||||
<div
|
||||
className={css({
|
||||
height: '100%',
|
||||
bg: 'blue.500',
|
||||
borderRadius: 'full',
|
||||
transition: 'width 0.05s linear',
|
||||
})}
|
||||
style={{ width: `${progressPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Continue button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleSkip}
|
||||
className={css({
|
||||
width: '100%',
|
||||
py: 2,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'medium',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Continue →
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { DatasetInfo } from '../types'
|
||||
|
||||
interface LoadingCardProps {
|
||||
datasetInfo: DatasetInfo | null
|
||||
message: string
|
||||
}
|
||||
|
||||
export function LoadingCard({ datasetInfo, message }: LoadingCardProps) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 3, animation: 'spin 1s linear infinite' })}>📥</div>
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'medium', color: 'gray.200', mb: 2 })}>
|
||||
Loading Dataset
|
||||
</div>
|
||||
<div className={css({ color: 'gray.400', fontSize: 'sm', mb: 3 })}>
|
||||
{message || 'Loading training images...'}
|
||||
</div>
|
||||
|
||||
{datasetInfo && (
|
||||
<div
|
||||
className={css({
|
||||
display: 'inline-block',
|
||||
px: 3,
|
||||
py: 1.5,
|
||||
bg: 'gray.700',
|
||||
borderRadius: 'lg',
|
||||
fontSize: 'sm',
|
||||
})}
|
||||
>
|
||||
<span className={css({ color: 'blue.400', fontWeight: 'bold' })}>
|
||||
{datasetInfo.total_images.toLocaleString()}
|
||||
</span>
|
||||
<span className={css({ color: 'gray.400' })}> images loaded</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { TrainingResult } from '../types'
|
||||
|
||||
interface ResultsCardProps {
|
||||
result: TrainingResult | null
|
||||
error: string | null
|
||||
onTrainAgain: () => void
|
||||
}
|
||||
|
||||
export function ResultsCard({ result, error, onTrainAgain }: ResultsCardProps) {
|
||||
if (error) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 4 })}>
|
||||
<div className={css({ fontSize: '3xl', mb: 3 })}>❌</div>
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'red.400', mb: 2 })}>
|
||||
Training Failed
|
||||
</div>
|
||||
<div className={css({ color: 'gray.400', fontSize: 'sm', mb: 4, maxWidth: '280px', mx: 'auto' })}>
|
||||
{error}
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onTrainAgain}
|
||||
className={css({
|
||||
px: 6,
|
||||
py: 3,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'bold',
|
||||
fontSize: 'md',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Try Again
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!result) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 3, animation: 'spin 1s linear infinite' })}>⏳</div>
|
||||
<div className={css({ color: 'gray.400' })}>Waiting for results...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const accuracy = result.final_accuracy ?? 0
|
||||
|
||||
return (
|
||||
<div className={css({ textAlign: 'center' })}>
|
||||
{/* Success icon */}
|
||||
<div className={css({ fontSize: '3xl', mb: 2 })}>🎉</div>
|
||||
|
||||
{/* Title */}
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'bold', color: 'green.400', mb: 3 })}>
|
||||
Training Complete!
|
||||
</div>
|
||||
|
||||
{/* Main accuracy */}
|
||||
<div
|
||||
className={css({
|
||||
fontSize: '4xl',
|
||||
fontWeight: 'bold',
|
||||
color: 'green.400',
|
||||
mb: 0.5,
|
||||
})}
|
||||
>
|
||||
{(accuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
<div className={css({ fontSize: 'sm', color: 'gray.500', mb: 4 })}>
|
||||
Final Accuracy
|
||||
</div>
|
||||
|
||||
{/* Stats grid */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, 1fr)',
|
||||
gap: 2,
|
||||
p: 3,
|
||||
bg: 'gray.900',
|
||||
borderRadius: 'lg',
|
||||
fontSize: 'sm',
|
||||
mb: 4,
|
||||
})}
|
||||
>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Epochs</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.epochs_trained ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Final Loss</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300', fontWeight: 'medium' })}>
|
||||
{result.final_loss?.toFixed(4) ?? '—'}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600', fontSize: 'xs' })}>Model</div>
|
||||
<div className={css({ color: 'green.400', fontWeight: 'medium' })}>
|
||||
{result.tfjs_exported ? '✓ Saved' : '—'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Model path */}
|
||||
{result.tfjs_exported && (
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 4 })}>
|
||||
Model exported and ready to use
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Train again button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={onTrainAgain}
|
||||
className={css({
|
||||
px: 6,
|
||||
py: 3,
|
||||
bg: 'blue.600',
|
||||
color: 'white',
|
||||
borderRadius: 'lg',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
fontWeight: 'bold',
|
||||
fontSize: 'md',
|
||||
_hover: { bg: 'blue.500' },
|
||||
})}
|
||||
>
|
||||
Train Again
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
|
||||
interface SetupCardProps {
|
||||
message: string
|
||||
}
|
||||
|
||||
export function SetupCard({ message }: SetupCardProps) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 3, animation: 'spin 1s linear infinite' })}>⚙️</div>
|
||||
<div className={css({ fontSize: 'lg', fontWeight: 'medium', color: 'gray.200', mb: 2 })}>
|
||||
Setting Up
|
||||
</div>
|
||||
<div className={css({ color: 'gray.400', fontSize: 'sm' })}>
|
||||
{message || 'Initializing Python environment...'}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
'use client'
|
||||
|
||||
import { css } from '../../../../../../../styled-system/css'
|
||||
import type { EpochData } from '../types'
|
||||
|
||||
interface TrainingCardProps {
|
||||
currentEpoch: EpochData | null
|
||||
totalEpochs: number
|
||||
bestAccuracy: number
|
||||
onCancel: () => void
|
||||
}
|
||||
|
||||
export function TrainingCard({
|
||||
currentEpoch,
|
||||
totalEpochs,
|
||||
bestAccuracy,
|
||||
onCancel,
|
||||
}: TrainingCardProps) {
|
||||
if (!currentEpoch) {
|
||||
return (
|
||||
<div className={css({ textAlign: 'center', py: 6 })}>
|
||||
<div className={css({ fontSize: '2xl', mb: 3, animation: 'spin 1s linear infinite' })}>🏋️</div>
|
||||
<div className={css({ color: 'gray.400' })}>Starting training...</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const progressPercent = Math.round((currentEpoch.epoch / currentEpoch.total_epochs) * 100)
|
||||
|
||||
return (
|
||||
<div className={css({ textAlign: 'center' })}>
|
||||
{/* Epoch counter */}
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 1 })}>
|
||||
Epoch {currentEpoch.epoch} of {totalEpochs}
|
||||
</div>
|
||||
|
||||
{/* Main accuracy */}
|
||||
<div
|
||||
className={css({
|
||||
fontSize: '3xl',
|
||||
fontWeight: 'bold',
|
||||
color: 'green.400',
|
||||
mb: 0.5,
|
||||
})}
|
||||
>
|
||||
{(currentEpoch.val_accuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.500', mb: 3 })}>
|
||||
Validation Accuracy
|
||||
</div>
|
||||
|
||||
{/* Progress bar */}
|
||||
<div className={css({ mb: 3 })}>
|
||||
<div
|
||||
className={css({
|
||||
height: '8px',
|
||||
bg: 'gray.700',
|
||||
borderRadius: 'full',
|
||||
overflow: 'hidden',
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={css({
|
||||
height: '100%',
|
||||
bg: 'blue.500',
|
||||
borderRadius: 'full',
|
||||
transition: 'width 0.3s ease',
|
||||
})}
|
||||
style={{ width: `${progressPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
<div className={css({ fontSize: 'xs', color: 'gray.600', mt: 1 })}>
|
||||
{progressPercent}% complete
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Metrics grid */}
|
||||
<div
|
||||
className={css({
|
||||
display: 'grid',
|
||||
gridTemplateColumns: 'repeat(3, 1fr)',
|
||||
gap: 2,
|
||||
p: 2,
|
||||
bg: 'gray.900',
|
||||
borderRadius: 'lg',
|
||||
fontSize: 'xs',
|
||||
mb: 4,
|
||||
})}
|
||||
>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600' })}>Loss</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300' })}>
|
||||
{currentEpoch.loss.toFixed(4)}
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600' })}>Train Acc</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'gray.300' })}>
|
||||
{(currentEpoch.accuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<div className={css({ color: 'gray.600' })}>Best</div>
|
||||
<div className={css({ fontFamily: 'mono', color: 'green.400' })}>
|
||||
{(bestAccuracy * 100).toFixed(1)}%
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Cancel button */}
|
||||
<button
|
||||
type="button"
|
||||
onClick={onCancel}
|
||||
className={css({
|
||||
px: 4,
|
||||
py: 2,
|
||||
bg: 'transparent',
|
||||
color: 'gray.500',
|
||||
fontSize: 'sm',
|
||||
borderRadius: 'lg',
|
||||
border: '1px solid',
|
||||
borderColor: 'gray.700',
|
||||
cursor: 'pointer',
|
||||
_hover: { borderColor: 'gray.600', color: 'gray.400' },
|
||||
})}
|
||||
>
|
||||
Cancel Training
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
// Phase and Card identifiers
|
||||
export type PhaseId = 'preparation' | 'training' | 'results'
|
||||
export type CardId =
|
||||
| 'data'
|
||||
| 'hardware'
|
||||
| 'config'
|
||||
| 'setup'
|
||||
| 'loading'
|
||||
| 'training'
|
||||
| 'export'
|
||||
| 'results'
|
||||
|
||||
// Card state relative to current position
|
||||
export type CardPosition = 'done' | 'current' | 'upcoming'
|
||||
|
||||
// Phase status
|
||||
export type PhaseStatus = 'done' | 'current' | 'upcoming'
|
||||
|
||||
// Phase definition
|
||||
export interface PhaseDefinition {
|
||||
id: PhaseId
|
||||
title: string
|
||||
cards: CardId[]
|
||||
}
|
||||
|
||||
// All phases in order
|
||||
export const PHASES: PhaseDefinition[] = [
|
||||
{
|
||||
id: 'preparation',
|
||||
title: 'Preparation',
|
||||
cards: ['data', 'hardware', 'config'],
|
||||
},
|
||||
{
|
||||
id: 'training',
|
||||
title: 'Training',
|
||||
cards: ['setup', 'loading', 'training', 'export'],
|
||||
},
|
||||
{
|
||||
id: 'results',
|
||||
title: 'Results',
|
||||
cards: ['results'],
|
||||
},
|
||||
]
|
||||
|
||||
// Card metadata
|
||||
export interface CardDefinition {
|
||||
id: CardId
|
||||
title: string
|
||||
icon: string
|
||||
autoProgress: boolean
|
||||
autoProgressDelay?: number // ms
|
||||
}
|
||||
|
||||
export const CARDS: Record<CardId, CardDefinition> = {
|
||||
data: {
|
||||
id: 'data',
|
||||
title: 'Training Data',
|
||||
icon: '📊',
|
||||
autoProgress: true,
|
||||
autoProgressDelay: 2000,
|
||||
},
|
||||
hardware: {
|
||||
id: 'hardware',
|
||||
title: 'Hardware',
|
||||
icon: '🔧',
|
||||
autoProgress: true,
|
||||
autoProgressDelay: 2000,
|
||||
},
|
||||
config: {
|
||||
id: 'config',
|
||||
title: 'Configuration',
|
||||
icon: '⚙️',
|
||||
autoProgress: false,
|
||||
},
|
||||
setup: {
|
||||
id: 'setup',
|
||||
title: 'Setup',
|
||||
icon: '🔄',
|
||||
autoProgress: true, // Event-driven
|
||||
},
|
||||
loading: {
|
||||
id: 'loading',
|
||||
title: 'Loading',
|
||||
icon: '📥',
|
||||
autoProgress: true, // Event-driven
|
||||
},
|
||||
training: {
|
||||
id: 'training',
|
||||
title: 'Training',
|
||||
icon: '🏋️',
|
||||
autoProgress: true, // Event-driven (when epochs complete)
|
||||
},
|
||||
export: {
|
||||
id: 'export',
|
||||
title: 'Export',
|
||||
icon: '📦',
|
||||
autoProgress: true, // Event-driven
|
||||
},
|
||||
results: {
|
||||
id: 'results',
|
||||
title: 'Results',
|
||||
icon: '🎉',
|
||||
autoProgress: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Card summaries for done state
|
||||
export interface CardSummary {
|
||||
label: string
|
||||
value: string
|
||||
}
|
||||
|
||||
// Wizard state
|
||||
export interface WizardState {
|
||||
currentPhaseIndex: number
|
||||
currentCardIndex: number
|
||||
}
|
||||
|
||||
// Data types (re-exported from existing)
|
||||
export interface DigitSample {
|
||||
count: number
|
||||
samplePath: string | null
|
||||
tilePaths: string[]
|
||||
}
|
||||
|
||||
export interface SamplesData {
|
||||
digits: Record<number, DigitSample>
|
||||
totalImages: number
|
||||
hasData: boolean
|
||||
dataQuality: 'none' | 'insufficient' | 'minimal' | 'good' | 'excellent'
|
||||
}
|
||||
|
||||
export interface HardwareInfo {
|
||||
available: boolean
|
||||
device: string
|
||||
deviceName: string
|
||||
deviceType: string
|
||||
details: Record<string, unknown>
|
||||
error: string | null
|
||||
hint?: string
|
||||
}
|
||||
|
||||
export interface EpochData {
|
||||
epoch: number
|
||||
total_epochs: number
|
||||
loss: number
|
||||
accuracy: number
|
||||
val_loss: number
|
||||
val_accuracy: number
|
||||
}
|
||||
|
||||
export interface DatasetInfo {
|
||||
total_images: number
|
||||
digit_counts: Record<number, number>
|
||||
}
|
||||
|
||||
export interface TrainingResult {
|
||||
final_accuracy: number
|
||||
final_loss: number
|
||||
epochs_trained: number
|
||||
output_dir: string
|
||||
tfjs_exported: boolean
|
||||
}
|
||||
|
||||
export interface TrainingConfig {
|
||||
epochs: number
|
||||
batchSize: number
|
||||
validationSplit: number
|
||||
augmentation: boolean
|
||||
}
|
||||
|
||||
// Server-side training phase (from SSE events)
|
||||
export type ServerPhase = 'idle' | 'setup' | 'loading' | 'training' | 'exporting' | 'complete' | 'error'
|
||||
|
||||
// Helper to get phase index
|
||||
export function getPhaseIndex(phaseId: PhaseId): number {
|
||||
return PHASES.findIndex((p) => p.id === phaseId)
|
||||
}
|
||||
|
||||
// Helper to get card index within a phase
|
||||
export function getCardIndexInPhase(phaseId: PhaseId, cardId: CardId): number {
|
||||
const phase = PHASES.find((p) => p.id === phaseId)
|
||||
return phase?.cards.indexOf(cardId) ?? -1
|
||||
}
|
||||
|
||||
// Helper to find which phase a card belongs to
|
||||
export function getPhaseForCard(cardId: CardId): PhaseId | null {
|
||||
for (const phase of PHASES) {
|
||||
if (phase.cards.includes(cardId)) {
|
||||
return phase.id
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
// Map server phase to wizard position
|
||||
export function serverPhaseToWizardPosition(
|
||||
serverPhase: ServerPhase
|
||||
): { phaseIndex: number; cardIndex: number } {
|
||||
switch (serverPhase) {
|
||||
case 'idle':
|
||||
return { phaseIndex: 0, cardIndex: 0 } // Start at data card
|
||||
case 'setup':
|
||||
return { phaseIndex: 1, cardIndex: 0 } // Training phase, setup card
|
||||
case 'loading':
|
||||
return { phaseIndex: 1, cardIndex: 1 } // Training phase, loading card
|
||||
case 'training':
|
||||
return { phaseIndex: 1, cardIndex: 2 } // Training phase, training card
|
||||
case 'exporting':
|
||||
return { phaseIndex: 1, cardIndex: 3 } // Training phase, export card
|
||||
case 'complete':
|
||||
return { phaseIndex: 2, cardIndex: 0 } // Results phase
|
||||
case 'error':
|
||||
return { phaseIndex: 2, cardIndex: 0 } // Show error in results phase
|
||||
default:
|
||||
return { phaseIndex: 0, cardIndex: 0 }
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user