feat: add LLM client package and worksheet parsing infrastructure
Part A: @soroban/llm-client package
- Multi-provider support (OpenAI, Anthropic) via env vars
- Zod schema validation for structured LLM responses
- Retry loop with validation error feedback in prompt
- Progress indication hooks for UI feedback
- Vision support for image analysis

Part B: Worksheet parsing feature
- Zod schemas for parsed worksheet problems
- LLM prompt builder for abacus workbook images
- Parser using llm.vision() with retry logic
- Session converter to create SlotResults for BKT
- Database migration for parsing columns
- API routes: /parse, /review, /approve workflow

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
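For orientation, a minimal client-side sketch of the parse → review → approve flow introduced here (illustrative; assumes `playerId`, `attachmentId`, and a `corrections` array are in scope; exact request/response shapes are defined in the route files below):

const base = `/api/curriculum/${playerId}/attachments/${attachmentId}`

// 1. Parse the uploaded worksheet image (POST runs the LLM and returns results)
const parsed = await fetch(`${base}/parse`, { method: 'POST' }).then((r) => r.json())

// 2. If the LLM flagged problems, submit corrections for review
if (parsed.result?.needsReview) {
  await fetch(`${base}/review`, {
    method: 'PATCH',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ corrections, markAsReviewed: true }),
  })
}

// 3. Approve and create a completed practice session
const { sessionId } = await fetch(`${base}/approve`, { method: 'POST' }).then((r) => r.json())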
@@ -462,7 +462,9 @@
       "Bash(apps/web/src/hooks/__tests__/useRemoteCameraDesktop.test.ts )",
       "Bash(apps/web/src/hooks/__tests__/useRemoteCameraPhone.test.ts )",
       "Bash(apps/web/src/lib/remote-camera/__tests__/)",
-      "Bash(packages/abacus-react/CHANGELOG.md )"
+      "Bash(packages/abacus-react/CHANGELOG.md )",
+      "WebFetch(domain:zod.dev)",
+      "Bash(npm view:*)"
     ],
     "deny": [],
     "ask": []
37
apps/web/drizzle/0055_add_attachment_parsing.sql
Normal file
@@ -0,0 +1,37 @@
-- Add LLM-powered worksheet parsing columns to practice_attachments
-- These columns support the workflow: parse → review → approve → create session

-- Parsing workflow status
ALTER TABLE `practice_attachments` ADD COLUMN `parsing_status` text;
--> statement-breakpoint

-- When parsing completed (ISO timestamp)
ALTER TABLE `practice_attachments` ADD COLUMN `parsed_at` text;
--> statement-breakpoint

-- Error message if parsing failed
ALTER TABLE `practice_attachments` ADD COLUMN `parsing_error` text;
--> statement-breakpoint

-- Raw LLM parsing result (JSON) - before user corrections
ALTER TABLE `practice_attachments` ADD COLUMN `raw_parsing_result` text;
--> statement-breakpoint

-- Approved result (JSON) - after user corrections
ALTER TABLE `practice_attachments` ADD COLUMN `approved_result` text;
--> statement-breakpoint

-- Overall confidence score from LLM (0-1)
ALTER TABLE `practice_attachments` ADD COLUMN `confidence_score` real;
--> statement-breakpoint

-- True if any problems need manual review
ALTER TABLE `practice_attachments` ADD COLUMN `needs_review` integer;
--> statement-breakpoint

-- True if a session was created from this parsed worksheet
ALTER TABLE `practice_attachments` ADD COLUMN `session_created` integer;
--> statement-breakpoint

-- Reference to the session created from this parsing
ALTER TABLE `practice_attachments` ADD COLUMN `created_session_id` text REFERENCES session_plans(id) ON DELETE SET NULL;
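These columns are read through the Drizzle schema updated later in this commit; a minimal sketch of querying attachments that still need review (illustrative):

import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { practiceAttachments } from '@/db/schema/practice-attachments'

// Find attachments whose parsing results still need human review
const needingReview = await db
  .select()
  .from(practiceAttachments)
  .where(eq(practiceAttachments.parsingStatus, 'needs_review'))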
@@ -386,6 +386,13 @@
       "when": 1767240895813,
       "tag": "0054_new_mathemanic",
       "breakpoints": true
     },
+    {
+      "idx": 55,
+      "version": "6",
+      "when": 1767398400000,
+      "tag": "0055_add_attachment_parsing",
+      "breakpoints": true
+    }
   ]
 }
@@ -54,6 +54,7 @@
     "@react-three/fiber": "^8.17.0",
     "@soroban/abacus-react": "workspace:*",
     "@soroban/core": "workspace:*",
+    "@soroban/llm-client": "workspace:*",
     "@soroban/templates": "workspace:*",
     "@strudel/soundfonts": "^1.2.6",
     "@strudel/web": "^1.2.6",
@@ -0,0 +1,198 @@
/**
 * API route for approving parsed worksheet results and creating a practice session
 *
 * POST /api/curriculum/[playerId]/attachments/[attachmentId]/approve
 * - Approves the parsing result
 * - Creates a practice session from the parsed problems
 * - Links the session back to the attachment
 */

import { NextResponse } from 'next/server'
import { eq } from 'drizzle-orm'
import { createId } from '@paralleldrive/cuid2'
import { db } from '@/db'
import { practiceAttachments } from '@/db/schema/practice-attachments'
import {
  sessionPlans,
  type SessionStatus,
  type SessionPart,
  type SessionSummary,
  type SlotResult,
} from '@/db/schema/session-plans'
import { canPerformAction } from '@/lib/classroom'
import { getDbUserId } from '@/lib/viewer'
import {
  convertToSlotResults,
  computeParsingStats,
} from '@/lib/worksheet-parsing'

interface RouteParams {
  params: Promise<{ playerId: string; attachmentId: string }>
}

/**
 * POST - Approve parsing and create practice session
 */
export async function POST(_request: Request, { params }: RouteParams) {
  try {
    const { playerId, attachmentId } = await params

    if (!playerId || !attachmentId) {
      return NextResponse.json({ error: 'Player ID and Attachment ID required' }, { status: 400 })
    }

    // Authorization check
    const userId = await getDbUserId()
    const canApprove = await canPerformAction(userId, playerId, 'start-session')
    if (!canApprove) {
      return NextResponse.json({ error: 'Not authorized' }, { status: 403 })
    }

    // Get attachment record
    const attachment = await db
      .select()
      .from(practiceAttachments)
      .where(eq(practiceAttachments.id, attachmentId))
      .get()

    if (!attachment) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    if (attachment.playerId !== playerId) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    // Check if a session was already created
    if (attachment.sessionCreated) {
      return NextResponse.json({
        error: 'Session already created from this attachment',
        sessionId: attachment.createdSessionId,
      }, { status: 400 })
    }

    // Get the parsing result to convert (prefer approved result, fall back to raw)
    const parsingResult = attachment.approvedResult ?? attachment.rawParsingResult
    if (!parsingResult) {
      return NextResponse.json({
        error: 'No parsing results available. Parse the worksheet first.',
      }, { status: 400 })
    }

    // Convert to slot results
    const conversionResult = convertToSlotResults(parsingResult, {
      partNumber: 1,
      source: 'practice',
    })

    if (conversionResult.slotResults.length === 0) {
      return NextResponse.json({
        error: 'No valid problems to create session from',
      }, { status: 400 })
    }

    // Create the session with completed status
    const sessionId = createId()
    const now = new Date()

    // Add timestamps to slot results
    const slotResultsWithTimestamps: SlotResult[] = conversionResult.slotResults.map((result) => ({
      ...result,
      timestamp: now,
    }))

    // Calculate session summary from results
    const correctCount = slotResultsWithTimestamps.filter((r) => r.isCorrect).length
    const totalCount = slotResultsWithTimestamps.length

    // Session status for parsed worksheets
    const sessionStatus: SessionStatus = 'completed'

    // Build minimal session part (offline worksheets are single-part)
    const offlinePart: SessionPart = {
      partNumber: 1,
      type: 'abacus', // Worksheet problems are solved on physical abacus
      format: 'vertical', // Most worksheets are vertical format
      useAbacus: true, // Assume physical abacus was used
      slots: slotResultsWithTimestamps.map((result, idx) => ({
        index: idx,
        purpose: 'review' as const,
        constraints: {},
        problem: result.problem,
      })),
      estimatedMinutes: 0, // Unknown for offline work
    }

    // Build session summary
    const sessionSummary: SessionSummary = {
      focusDescription: 'Worksheet practice (offline)',
      totalProblemCount: totalCount,
      estimatedMinutes: 0,
      parts: [
        {
          partNumber: 1,
          type: 'abacus',
          description: 'Worksheet Problems',
          problemCount: totalCount,
          estimatedMinutes: 0,
        },
      ],
    }

    // Create the session
    await db.insert(sessionPlans).values({
      id: sessionId,
      playerId,
      status: sessionStatus,

      // Required setup parameters
      targetDurationMinutes: 0, // Unknown for offline
      estimatedProblemCount: totalCount,
      avgTimePerProblemSeconds: 0, // Unknown for offline

      // Plan content
      parts: [offlinePart],
      summary: sessionSummary,
      masteredSkillIds: [], // Not tracked for offline

      // Session state
      currentPartIndex: 0,
      currentSlotIndex: totalCount - 1, // Completed

      // Results
      results: slotResultsWithTimestamps,

      // Timestamps
      createdAt: now,
      approvedAt: now,
      startedAt: now,
      completedAt: now,
    })

    // Update attachment to mark session as created
    await db
      .update(practiceAttachments)
      .set({
        parsingStatus: 'approved',
        sessionCreated: true,
        createdSessionId: sessionId,
      })
      .where(eq(practiceAttachments.id, attachmentId))

    // Compute final stats
    const stats = computeParsingStats(parsingResult)

    return NextResponse.json({
      success: true,
      sessionId,
      problemCount: totalCount,
      correctCount,
      accuracy: totalCount > 0 ? correctCount / totalCount : null,
      skillsExercised: conversionResult.skillsExercised,
      stats,
    })
  } catch (error) {
    console.error('Error approving and creating session:', error)
    return NextResponse.json({ error: 'Failed to approve and create session' }, { status: 500 })
  }
}
@@ -0,0 +1,207 @@
/**
 * API route for LLM-powered worksheet parsing
 *
 * POST /api/curriculum/[playerId]/attachments/[attachmentId]/parse
 * - Parses the attachment image and returns the result when complete
 * - Status is persisted along the way, so GET can report progress or results
 *
 * GET /api/curriculum/[playerId]/attachments/[attachmentId]/parse
 * - Get current parsing status and results
 */

import { readFile } from 'fs/promises'
import { NextResponse } from 'next/server'
import { join } from 'path'
import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { practiceAttachments, type ParsingStatus } from '@/db/schema/practice-attachments'
import { canPerformAction } from '@/lib/classroom'
import { getDbUserId } from '@/lib/viewer'
import {
  parseWorksheetImage,
  computeParsingStats,
  type WorksheetParsingResult,
} from '@/lib/worksheet-parsing'

interface RouteParams {
  params: Promise<{ playerId: string; attachmentId: string }>
}

/**
 * POST - Start parsing the attachment
 */
export async function POST(_request: Request, { params }: RouteParams) {
  try {
    const { playerId, attachmentId } = await params

    if (!playerId || !attachmentId) {
      return NextResponse.json({ error: 'Player ID and Attachment ID required' }, { status: 400 })
    }

    // Authorization check
    const userId = await getDbUserId()
    const canParse = await canPerformAction(userId, playerId, 'start-session')
    if (!canParse) {
      return NextResponse.json({ error: 'Not authorized' }, { status: 403 })
    }

    // Get attachment record
    const attachment = await db
      .select()
      .from(practiceAttachments)
      .where(eq(practiceAttachments.id, attachmentId))
      .get()

    if (!attachment) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    if (attachment.playerId !== playerId) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    // Check if already processing
    if (attachment.parsingStatus === 'processing') {
      return NextResponse.json({
        status: 'processing',
        message: 'Parsing already in progress',
      })
    }

    // Update status to processing
    await db
      .update(practiceAttachments)
      .set({
        parsingStatus: 'processing',
        parsingError: null,
      })
      .where(eq(practiceAttachments.id, attachmentId))

    // Read the image file
    const uploadDir = join(process.cwd(), 'data', 'uploads', 'players', playerId)
    const filepath = join(uploadDir, attachment.filename)
    const imageBuffer = await readFile(filepath)
    const base64Image = imageBuffer.toString('base64')
    const mimeType = attachment.mimeType || 'image/jpeg'
    const imageDataUrl = `data:${mimeType};base64,${base64Image}`

    try {
      // Parse the worksheet
      const result = await parseWorksheetImage(imageDataUrl, {
        maxRetries: 2,
      })

      const parsingResult = result.data
      const stats = computeParsingStats(parsingResult)

      // Determine status based on confidence
      const status: ParsingStatus = parsingResult.needsReview ? 'needs_review' : 'approved'

      // Save results to database
      await db
        .update(practiceAttachments)
        .set({
          parsingStatus: status,
          parsedAt: new Date().toISOString(),
          rawParsingResult: parsingResult,
          confidenceScore: parsingResult.overallConfidence,
          needsReview: parsingResult.needsReview,
          parsingError: null,
        })
        .where(eq(practiceAttachments.id, attachmentId))

      return NextResponse.json({
        success: true,
        status,
        result: parsingResult,
        stats,
        attempts: result.attempts,
        usage: result.usage,
      })
    } catch (parseError) {
      const errorMessage = parseError instanceof Error ? parseError.message : 'Unknown parsing error'
      console.error('Worksheet parsing error:', parseError)

      // Update status to failed
      await db
        .update(practiceAttachments)
        .set({
          parsingStatus: 'failed',
          parsingError: errorMessage,
        })
        .where(eq(practiceAttachments.id, attachmentId))

      return NextResponse.json({
        success: false,
        status: 'failed',
        error: errorMessage,
      }, { status: 500 })
    }
  } catch (error) {
    console.error('Error starting parse:', error)
    return NextResponse.json({ error: 'Failed to start parsing' }, { status: 500 })
  }
}

/**
 * GET - Get parsing status and results
 */
export async function GET(_request: Request, { params }: RouteParams) {
  try {
    const { playerId, attachmentId } = await params

    if (!playerId || !attachmentId) {
      return NextResponse.json({ error: 'Player ID and Attachment ID required' }, { status: 400 })
    }

    // Authorization check
    const userId = await getDbUserId()
    const canView = await canPerformAction(userId, playerId, 'view')
    if (!canView) {
      return NextResponse.json({ error: 'Not authorized' }, { status: 403 })
    }

    // Get attachment record
    const attachment = await db
      .select()
      .from(practiceAttachments)
      .where(eq(practiceAttachments.id, attachmentId))
      .get()

    if (!attachment) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    if (attachment.playerId !== playerId) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    // Build response based on status
    const response: {
      status: ParsingStatus | null
      parsedAt: string | null
      result: WorksheetParsingResult | null
      error: string | null
      needsReview: boolean
      confidenceScore: number | null
      stats?: ReturnType<typeof computeParsingStats>
    } = {
      status: attachment.parsingStatus,
      parsedAt: attachment.parsedAt,
      result: attachment.rawParsingResult,
      error: attachment.parsingError,
      needsReview: attachment.needsReview === true,
      confidenceScore: attachment.confidenceScore,
    }

    // Add stats if we have results
    if (attachment.rawParsingResult) {
      response.stats = computeParsingStats(attachment.rawParsingResult)
    }

    return NextResponse.json(response)
  } catch (error) {
    console.error('Error getting parse status:', error)
    return NextResponse.json({ error: 'Failed to get parsing status' }, { status: 500 })
  }
}
@@ -0,0 +1,135 @@
/**
 * API route for reviewing and correcting parsed worksheet results
 *
 * PATCH /api/curriculum/[playerId]/attachments/[attachmentId]/review
 * - Submit user corrections to parsed problems
 * - Updates the parsing result with corrections
 */

import { NextResponse } from 'next/server'
import { eq } from 'drizzle-orm'
import { z } from 'zod'
import { db } from '@/db'
import { practiceAttachments, type ParsingStatus } from '@/db/schema/practice-attachments'
import { canPerformAction } from '@/lib/classroom'
import { getDbUserId } from '@/lib/viewer'
import {
  applyCorrections,
  computeParsingStats,
  ProblemCorrectionSchema,
} from '@/lib/worksheet-parsing'

interface RouteParams {
  params: Promise<{ playerId: string; attachmentId: string }>
}

/**
 * Request body schema for corrections
 */
const ReviewRequestSchema = z.object({
  corrections: z.array(ProblemCorrectionSchema).min(1),
  markAsReviewed: z.boolean().default(false),
})
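// Example request body (illustrative; field meanings are defined by
// ProblemCorrectionSchema in @/lib/worksheet-parsing):
//
// {
//   "corrections": [
//     {
//       "problemNumber": 5,
//       "correctedTerms": [45, -17, 8],
//       "correctedStudentAnswer": 36,
//       "shouldExclude": false,
//       "note": "LLM misread the middle term"
//     }
//   ],
//   "markAsReviewed": true
// }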

/**
 * PATCH - Submit corrections to parsed problems
 */
export async function PATCH(request: Request, { params }: RouteParams) {
  try {
    const { playerId, attachmentId } = await params

    if (!playerId || !attachmentId) {
      return NextResponse.json({ error: 'Player ID and Attachment ID required' }, { status: 400 })
    }

    // Authorization check
    const userId = await getDbUserId()
    const canReview = await canPerformAction(userId, playerId, 'start-session')
    if (!canReview) {
      return NextResponse.json({ error: 'Not authorized' }, { status: 403 })
    }

    // Parse request body
    const body = await request.json()
    const parseResult = ReviewRequestSchema.safeParse(body)
    if (!parseResult.success) {
      return NextResponse.json({
        error: 'Invalid request body',
        details: parseResult.error.issues,
      }, { status: 400 })
    }

    const { corrections, markAsReviewed } = parseResult.data

    // Get attachment record
    const attachment = await db
      .select()
      .from(practiceAttachments)
      .where(eq(practiceAttachments.id, attachmentId))
      .get()

    if (!attachment) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    if (attachment.playerId !== playerId) {
      return NextResponse.json({ error: 'Attachment not found' }, { status: 404 })
    }

    // Check if we have parsing results to correct
    if (!attachment.rawParsingResult) {
      return NextResponse.json({
        error: 'No parsing results to correct. Parse the worksheet first.',
      }, { status: 400 })
    }

    // Apply corrections to the raw result
    const correctedResult = applyCorrections(
      attachment.rawParsingResult,
      corrections.map((c) => ({
        problemNumber: c.problemNumber,
        correctedTerms: c.correctedTerms ?? undefined,
        correctedStudentAnswer: c.correctedStudentAnswer ?? undefined,
        shouldExclude: c.shouldExclude,
      }))
    )

    // Compute new stats
    const stats = computeParsingStats(correctedResult)

    // Determine new status
    let newStatus: ParsingStatus = attachment.parsingStatus ?? 'needs_review'
    if (markAsReviewed) {
      // If user explicitly marks as reviewed, set to approved
      newStatus = 'approved'
    } else if (!correctedResult.needsReview) {
      // If all problems now have high confidence, auto-approve
      newStatus = 'approved'
    } else {
      // Still needs review
      newStatus = 'needs_review'
    }

    // Update database - store corrected result as approved result
    await db
      .update(practiceAttachments)
      .set({
        parsingStatus: newStatus,
        approvedResult: correctedResult,
        confidenceScore: correctedResult.overallConfidence,
        needsReview: correctedResult.needsReview,
      })
      .where(eq(practiceAttachments.id, attachmentId))

    return NextResponse.json({
      success: true,
      status: newStatus,
      result: correctedResult,
      stats,
      correctionsApplied: corrections.length,
    })
  } catch (error) {
    console.error('Error applying corrections:', error)
    return NextResponse.json({ error: 'Failed to apply corrections' }, { status: 500 })
  }
}
@@ -1,14 +1,23 @@
-import { sqliteTable, text, integer } from 'drizzle-orm/sqlite-core'
+import { sqliteTable, text, integer, real } from 'drizzle-orm/sqlite-core'
 import { createId } from '@paralleldrive/cuid2'
 import { players } from './players'
+import { sessionPlans } from './session-plans'
 import { users } from './users'
+import type { WorksheetParsingResult } from '@/lib/worksheet-parsing'
+
+/**
+ * Parsing workflow status
+ */
+export type ParsingStatus = 'pending' | 'processing' | 'needs_review' | 'approved' | 'failed'
 
 /**
  * Practice attachments - photos of student work
  *
  * Used primarily for offline practice sessions where parents/teachers
  * upload photos of the student's physical abacus work.
+ *
+ * Now also supports LLM-powered parsing of worksheet images to extract
+ * problems and student answers automatically.
  */
 export const practiceAttachments = sqliteTable('practice_attachments', {
   id: text('id')
@@ -41,6 +50,29 @@ export const practiceAttachments = sqliteTable('practice_attachments', {
   // Rotation in degrees (0, 90, 180, or 270) - applied after cropping
   rotation: integer('rotation').$type<0 | 90 | 180 | 270>().default(0),
 
+  // ============================================================================
+  // LLM Parsing Workflow
+  // ============================================================================
+
+  // Parsing status
+  parsingStatus: text('parsing_status').$type<ParsingStatus>(),
+  parsedAt: text('parsed_at'), // ISO timestamp when parsing completed
+  parsingError: text('parsing_error'), // Error message if parsing failed
+
+  // LLM parsing results (raw from LLM, before user corrections)
+  rawParsingResult: text('raw_parsing_result', { mode: 'json' }).$type<WorksheetParsingResult | null>(),
+
+  // Approved results (after user corrections)
+  approvedResult: text('approved_result', { mode: 'json' }).$type<WorksheetParsingResult | null>(),
+
+  // Confidence and review indicators
+  confidenceScore: real('confidence_score'), // 0-1, from LLM
+  needsReview: integer('needs_review', { mode: 'boolean' }), // True if any problems need manual review
+
+  // Session linkage (for parsed worksheets that created sessions)
+  sessionCreated: integer('session_created', { mode: 'boolean' }), // True if session was created from this parsing
+  createdSessionId: text('created_session_id').references(() => sessionPlans.id, { onDelete: 'set null' }),
+
   // Audit
   uploadedBy: text('uploaded_by')
     .notNull()
158
apps/web/src/hooks/useLLMCall.ts
Normal file
@@ -0,0 +1,158 @@
/**
 * React hooks for making LLM calls with progress tracking
 *
 * These hooks integrate the LLM client with React Query for proper
 * state management, caching, and UI feedback.
 *
 * @example
 * ```typescript
 * import { useLLMCall } from '@/hooks/useLLMCall'
 * import { z } from 'zod'
 *
 * const SentimentSchema = z.object({
 *   sentiment: z.enum(['positive', 'negative', 'neutral']),
 *   confidence: z.number(),
 * })
 *
 * function MyComponent() {
 *   const { mutate, progress, isPending, error, data } = useLLMCall(SentimentSchema)
 *
 *   return (
 *     <div>
 *       <button onClick={() => mutate({ prompt: 'Analyze: I love this!' })}>
 *         Analyze
 *       </button>
 *       {progress && <div>{progress.message}</div>}
 *       {data && <div>Sentiment: {data.data.sentiment}</div>}
 *     </div>
 *   )
 * }
 * ```
 */

import { useState, useCallback } from 'react'
import { useMutation, type UseMutationOptions } from '@tanstack/react-query'
import type { z } from 'zod'
import { llm, type LLMProgress, type LLMResponse } from '@/lib/llm'

/** Request options for an LLM call (without schema) */
interface LLMCallRequest {
  prompt: string
  images?: string[]
  provider?: string
  model?: string
  maxRetries?: number
}

/** Request options for a vision call (requires images) */
interface LLMVisionRequest extends LLMCallRequest {
  images: string[]
}

/**
 * Hook for making type-safe LLM calls with progress tracking
 *
 * @param schema - Zod schema for validating the LLM response
 * @param options - Optional React Query mutation options
 */
export function useLLMCall<T extends z.ZodType>(
  schema: T,
  options?: Omit<
    UseMutationOptions<LLMResponse<z.infer<T>>, Error, LLMCallRequest>,
    'mutationFn'
  >
) {
  const [progress, setProgress] = useState<LLMProgress | null>(null)

  const mutation = useMutation({
    mutationFn: async (request: LLMCallRequest) => {
      setProgress(null)
      return llm.call({
        ...request,
        schema,
        onProgress: setProgress,
      })
    },
    onSettled: () => {
      setProgress(null)
    },
    ...options,
  })

  return {
    ...mutation,
    progress,
  }
}

/**
 * Hook for making vision (image + text) LLM calls with progress tracking
 *
 * @param schema - Zod schema for validating the LLM response
 * @param options - Optional React Query mutation options
 *
 * @example
 * ```typescript
 * const { mutate, progress } = useLLMVision(ImageAnalysisSchema)
 *
 * mutate({
 *   prompt: 'Describe this image',
 *   images: ['data:image/jpeg;base64,...'],
 * })
 * ```
 */
export function useLLMVision<T extends z.ZodType>(
  schema: T,
  options?: Omit<
    UseMutationOptions<LLMResponse<z.infer<T>>, Error, LLMVisionRequest>,
    'mutationFn'
  >
) {
  const [progress, setProgress] = useState<LLMProgress | null>(null)

  const mutation = useMutation({
    mutationFn: async (request: LLMVisionRequest) => {
      setProgress(null)
      return llm.vision({
        ...request,
        schema,
        onProgress: setProgress,
      })
    },
    onSettled: () => {
      setProgress(null)
    },
    ...options,
  })

  return {
    ...mutation,
    progress,
  }
}

/**
 * Hook for getting LLM client status and configuration
 *
 * @example
 * ```typescript
 * const { providers, isProviderAvailable, defaultProvider } = useLLMStatus()
 *
 * if (!isProviderAvailable('openai')) {
 *   return <div>OpenAI is not configured</div>
 * }
 * ```
 */
export function useLLMStatus() {
  const getProviders = useCallback(() => llm.getProviders(), [])
  const isProviderAvailable = useCallback((name: string) => llm.isProviderAvailable(name), [])
  const getDefaultProvider = useCallback(() => llm.getDefaultProvider(), [])
  const getDefaultModel = useCallback((provider?: string) => llm.getDefaultModel(provider), [])

  return {
    providers: getProviders(),
    isProviderAvailable,
    defaultProvider: getDefaultProvider(),
    getDefaultModel,
  }
}
52
apps/web/src/lib/llm.ts
Normal file
@@ -0,0 +1,52 @@
/**
 * LLM Client Singleton for apps/web
 *
 * This module provides a singleton instance of the LLM client that reads
 * configuration from environment variables. The client supports multiple
 * providers (OpenAI, Anthropic) and provides type-safe LLM calls with
 * Zod schema validation.
 *
 * @example
 * ```typescript
 * import { llm } from '@/lib/llm'
 * import { z } from 'zod'
 *
 * const response = await llm.call({
 *   prompt: 'Analyze this text...',
 *   schema: z.object({ sentiment: z.enum(['positive', 'negative', 'neutral']) }),
 * })
 * ```
 *
 * @see packages/llm-client/README.md for full documentation
 */

import { LLMClient } from '@soroban/llm-client'

// Create singleton instance
// Configuration is automatically loaded from environment variables:
// - LLM_DEFAULT_PROVIDER: Default provider (default: 'openai')
// - LLM_DEFAULT_MODEL: Default model override
// - LLM_OPENAI_API_KEY: OpenAI API key
// - LLM_OPENAI_BASE_URL: OpenAI base URL (optional)
// - LLM_ANTHROPIC_API_KEY: Anthropic API key
// - LLM_ANTHROPIC_BASE_URL: Anthropic base URL (optional)
export const llm = new LLMClient()
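// Example: overriding the configured provider/model on a single call
// (illustrative; `provider` and `model` are optional request fields, as in
// useLLMCall's LLMCallRequest; `MySchema` and the model id are hypothetical):
//
//   const response = await llm.call({
//     prompt: 'Summarize this worksheet...',
//     schema: MySchema,
//     provider: 'anthropic',
//     model: 'some-anthropic-model-id',
//   })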

// Re-export types and utilities for convenience
export type {
  LLMClientConfig,
  LLMRequest,
  LLMResponse,
  LLMProgress,
  LLMProvider,
  ProviderConfig,
  ProviderRequest,
  ProviderResponse,
  ValidationFeedback,
} from '@soroban/llm-client'

export {
  LLMValidationError,
  LLMApiError,
  ProviderNotConfiguredError,
} from '@soroban/llm-client'
77
apps/web/src/lib/worksheet-parsing/index.ts
Normal file
@@ -0,0 +1,77 @@
/**
 * Worksheet Parsing Module
 *
 * Provides LLM-powered parsing of abacus workbook page images.
 * Extracts arithmetic problems and student answers, then converts
 * them into practice session data.
 *
 * @example
 * ```typescript
 * import {
 *   parseWorksheetImage,
 *   convertToSlotResults,
 *   type WorksheetParsingResult,
 * } from '@/lib/worksheet-parsing'
 *
 * // Parse the worksheet image
 * const result = await parseWorksheetImage(imageDataUrl, {
 *   onProgress: (p) => setProgress(p.message),
 * })
 *
 * // Review and correct if needed
 * if (result.data.needsReview) {
 *   // Show review UI
 * }
 *
 * // Convert to session data
 * const { slotResults, summary } = convertToSlotResults(result.data)
 *
 * // Create session
 * await createSession({ playerId, slotResults, status: 'completed' })
 * ```
 */

// Schemas
export {
  BoundingBoxSchema,
  ProblemFormatSchema,
  ProblemTermSchema,
  ParsedProblemSchema,
  PageMetadataSchema,
  WorksheetParsingResultSchema,
  ProblemCorrectionSchema,
  ReparseRequestSchema,
  type BoundingBox,
  type ProblemFormat,
  type ParsedProblem,
  type PageMetadata,
  type WorksheetParsingResult,
  type ProblemCorrection,
  type ReparseRequest,
} from './schemas'

// Parser
export {
  parseWorksheetImage,
  reparseProblems,
  computeParsingStats,
  applyCorrections,
  type ParseWorksheetOptions,
  type ParseWorksheetResult,
} from './parser'

// Prompt Builder
export {
  buildWorksheetParsingPrompt,
  buildReparsePrompt,
  type PromptOptions,
} from './prompt-builder'

// Session Converter
export {
  convertToSlotResults,
  validateParsedProblems,
  computeSkillStats,
  type ConversionOptions,
  type ConversionResult,
} from './session-converter'
215
apps/web/src/lib/worksheet-parsing/parser.ts
Normal file
@@ -0,0 +1,215 @@
/**
 * Worksheet Parser
 *
 * Uses the LLM client to parse abacus workbook page images
 * into structured problem data.
 */
import { llm, type LLMProgress } from '@/lib/llm'
import { WorksheetParsingResultSchema, type WorksheetParsingResult } from './schemas'
import { buildWorksheetParsingPrompt, type PromptOptions } from './prompt-builder'

/**
 * Options for parsing a worksheet
 */
export interface ParseWorksheetOptions {
  /** Progress callback for UI updates */
  onProgress?: (progress: LLMProgress) => void
  /** Maximum retries on validation failure */
  maxRetries?: number
  /** Additional prompt customization */
  promptOptions?: PromptOptions
  /** Specific provider to use (defaults to configured default) */
  provider?: string
  /** Specific model to use (defaults to configured default) */
  model?: string
}

/**
 * Result of worksheet parsing
 */
export interface ParseWorksheetResult {
  /** Parsed worksheet data */
  data: WorksheetParsingResult
  /** Number of LLM call attempts made */
  attempts: number
  /** Provider used */
  provider: string
  /** Model used */
  model: string
  /** Token usage */
  usage: {
    promptTokens: number
    completionTokens: number
    totalTokens: number
  }
}

/**
 * Parse an abacus workbook page image
 *
 * @param imageDataUrl - Base64-encoded data URL of the worksheet image
 * @param options - Parsing options
 * @returns Structured parsing result
 *
 * @example
 * ```typescript
 * import { parseWorksheetImage } from '@/lib/worksheet-parsing'
 *
 * const result = await parseWorksheetImage(imageDataUrl, {
 *   onProgress: (p) => console.log(p.message),
 * })
 *
 * console.log(`Found ${result.data.problems.length} problems`)
 * console.log(`Overall confidence: ${result.data.overallConfidence}`)
 * ```
 */
export async function parseWorksheetImage(
  imageDataUrl: string,
  options: ParseWorksheetOptions = {}
): Promise<ParseWorksheetResult> {
  const { onProgress, maxRetries = 2, promptOptions = {}, provider, model } = options

  // Build the prompt
  const prompt = buildWorksheetParsingPrompt(promptOptions)

  // Make the vision call
  const response = await llm.vision({
    prompt,
    images: [imageDataUrl],
    schema: WorksheetParsingResultSchema,
    maxRetries,
    onProgress,
    provider,
    model,
  })

  return {
    data: response.data,
    attempts: response.attempts,
    provider: response.provider,
    model: response.model,
    usage: response.usage,
  }
}

/**
 * Re-parse specific problems with additional context
 *
 * Used when the user provides corrections or hints about specific problems
 * that were incorrectly parsed in the first attempt.
 *
 * @param imageDataUrl - Base64-encoded data URL of the worksheet image
 * @param problemNumbers - Which problems to focus on
 * @param additionalContext - User-provided context or hints
 * @param originalWarnings - Warnings from the original parse
 * @param options - Parsing options
 */
export async function reparseProblems(
  imageDataUrl: string,
  problemNumbers: number[],
  additionalContext: string,
  originalWarnings: string[],
  options: Omit<ParseWorksheetOptions, 'promptOptions'> = {}
): Promise<ParseWorksheetResult> {
  return parseWorksheetImage(imageDataUrl, {
    ...options,
    promptOptions: {
      focusProblemNumbers: problemNumbers,
      additionalContext: `${additionalContext}

Previous warnings for these problems:
${originalWarnings.map((w) => `- ${w}`).join('\n')}`,
    },
  })
}

/**
 * Compute problem statistics from parsed results
 */
export function computeParsingStats(result: WorksheetParsingResult) {
  const problems = result.problems

  // Count problems needing review (low confidence)
  const lowConfidenceProblems = problems.filter(
    (p) => p.termsConfidence < 0.7 || p.studentAnswerConfidence < 0.7
  )

  // Count problems with answers
  const answeredProblems = problems.filter((p) => p.studentAnswer !== null)

  // Compute accuracy if answers are present
  const correctAnswers = answeredProblems.filter(
    (p) => p.studentAnswer === p.correctAnswer
  )

  return {
    totalProblems: problems.length,
    answeredProblems: answeredProblems.length,
    unansweredProblems: problems.length - answeredProblems.length,
    correctAnswers: correctAnswers.length,
    incorrectAnswers: answeredProblems.length - correctAnswers.length,
    accuracy: answeredProblems.length > 0 ? correctAnswers.length / answeredProblems.length : null,
    lowConfidenceCount: lowConfidenceProblems.length,
    problemsNeedingReview: lowConfidenceProblems.map((p) => p.problemNumber),
    warningCount: result.warnings.length,
  }
}
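// Example: surfacing these stats in a review UI (illustrative;
// `highlightProblems` is a hypothetical UI helper):
//
//   const stats = computeParsingStats(result.data)
//   console.log(`${stats.correctAnswers}/${stats.answeredProblems} correct`)
//   if (stats.lowConfidenceCount > 0) {
//     highlightProblems(stats.problemsNeedingReview)
//   }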

/**
 * Merge corrections into parsing result
 *
 * Creates a new result with user corrections applied.
 */
export function applyCorrections(
  result: WorksheetParsingResult,
  corrections: Array<{
    problemNumber: number
    correctedTerms?: number[] | null
    correctedStudentAnswer?: number | null
    shouldExclude?: boolean
  }>
): WorksheetParsingResult {
  const correctionMap = new Map(corrections.map((c) => [c.problemNumber, c]))

  const correctedProblems = result.problems
    .map((problem) => {
      const correction = correctionMap.get(problem.problemNumber)
      if (!correction) return problem
      if (correction.shouldExclude) return null

      return {
        ...problem,
        terms: correction.correctedTerms ?? problem.terms,
        correctAnswer: correction.correctedTerms
          ? correction.correctedTerms.reduce((sum, t) => sum + t, 0)
          : problem.correctAnswer,
        studentAnswer:
          correction.correctedStudentAnswer !== undefined
            ? correction.correctedStudentAnswer
            : problem.studentAnswer,
        // Boost confidence since user verified
        termsConfidence: correction.correctedTerms ? 1.0 : problem.termsConfidence,
        studentAnswerConfidence:
          correction.correctedStudentAnswer !== undefined
            ? 1.0
            : problem.studentAnswerConfidence,
      }
    })
    .filter((p): p is NonNullable<typeof p> => p !== null)

  // Recalculate overall confidence
  const avgConfidence =
    correctedProblems.reduce(
      (sum, p) => sum + (p.termsConfidence + p.studentAnswerConfidence) / 2,
      0
    ) / correctedProblems.length

  return {
    ...result,
    problems: correctedProblems,
    overallConfidence: avgConfidence,
    needsReview: correctedProblems.some(
      (p) => p.termsConfidence < 0.7 || p.studentAnswerConfidence < 0.7
    ),
  }
}
149
apps/web/src/lib/worksheet-parsing/prompt-builder.ts
Normal file
@@ -0,0 +1,149 @@
/**
 * Prompt Builder for Worksheet Parsing
 *
 * Constructs the prompt used to parse abacus workbook pages.
 * The prompt provides context about worksheet formats and
 * guides the LLM on how to extract problem data.
 */

/**
 * Options for customizing the parsing prompt
 */
export interface PromptOptions {
  /** Additional context from a previous parse attempt (for re-parsing) */
  additionalContext?: string
  /** Specific problem numbers to focus on (for re-parsing) */
  focusProblemNumbers?: number[]
  /** Hint about expected format if known */
  expectedFormat?: 'vertical' | 'linear' | 'mixed'
  /** Expected number of problems (if known from worksheet metadata) */
  expectedProblemCount?: number
}

/**
 * Build the main worksheet parsing prompt
 *
 * This prompt is designed to guide the LLM in extracting
 * structured data from abacus workbook page images.
 */
export function buildWorksheetParsingPrompt(options: PromptOptions = {}): string {
  const parts: string[] = []

  // Main task description
  parts.push(`You are analyzing an image of an abacus workbook page. Your task is to extract all arithmetic problems from the page along with any student answers written in the answer boxes.

## Worksheet Context

This is a Japanese soroban (abacus) practice worksheet. These worksheets typically contain:
- 1-4 rows of problems
- 8-10 problems per row (32-40 problems on a full page)
- Each problem has 2-7 terms (numbers to add or subtract)
- Problems are either VERTICAL format (stacked columns) or LINEAR format (horizontal equations)

## Problem Format Recognition

**VERTICAL FORMAT:**
Problems are arranged in columns with numbers stacked vertically. Addition is implied between numbers. Subtraction is indicated by a minus sign or horizontal line. The answer box is at the bottom.

Example:
   45
  -17
  + 8
  ----
  [36]  ← answer box

In this case: terms = [45, -17, 8], correctAnswer = 36

**LINEAR FORMAT:**
Problems are written as horizontal equations with operators between numbers.

Example: 45 - 17 + 8 = [36]

In this case: terms = [45, -17, 8], correctAnswer = 36

## Student Answer Reading

- Look carefully at the answer boxes/spaces for student handwriting
- Student handwriting may be messy - try to interpret digits carefully
- If an answer is empty, set studentAnswer to null
- If you cannot confidently read the answer, set studentAnswer to null and lower studentAnswerConfidence
- Common handwriting confusions to watch for:
  - 1 vs 7 (some students cross their 7s)
  - 4 vs 9
  - 5 vs 6
  - 0 vs 6

## Bounding Boxes

For each problem, provide bounding boxes in normalized coordinates (0-1):
- x, y: top-left corner as fraction of image dimensions
- width, height: size as fraction of image dimensions

The problemBoundingBox should encompass the entire problem including terms and answer area.
The answerBoundingBox should tightly surround just the answer area.`)

  // Add expected format hint if provided
  if (options.expectedFormat) {
    parts.push(`

## Format Hint
The problems on this page are expected to be in ${options.expectedFormat.toUpperCase()} format.`)
  }

  // Add expected count if provided
  if (options.expectedProblemCount) {
    parts.push(`

## Expected Problem Count
This worksheet should contain approximately ${options.expectedProblemCount} problems. If you detect significantly more or fewer, double-check for missed or duplicate problems.`)
  }

  // Add focus problems for re-parsing
  if (options.focusProblemNumbers && options.focusProblemNumbers.length > 0) {
    parts.push(`

## Focus Problems
Pay special attention to problems: ${options.focusProblemNumbers.join(', ')}. The previous parsing attempt had issues with these problems.`)
  }

  // Add additional context from user
  if (options.additionalContext) {
    parts.push(`

## Additional Context from User
${options.additionalContext}`)
  }

  // Final instructions
  parts.push(`

## Important Notes

1. **Reading Order**: Extract problems in reading order (left to right, top to bottom)
2. **Problem Numbers**: Use the printed problem numbers on the worksheet (1, 2, 3, etc.)
3. **Term Signs**: The first term is always positive. Subsequent terms are positive for addition, negative for subtraction
4. **Confidence Scores**: Be honest about confidence - lower scores help identify problems needing review
5. **Warnings**: Include any issues you notice (cropped problems, smudges, unclear digits)
6. **needsReview**: Set to true if any problem has confidence below 0.7 or significant warnings

Now analyze the worksheet image and extract all problems.`)

  return parts.join('')
}
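// Example: building a prompt with layout hints (illustrative; both fields
// come from PromptOptions above):
//
//   const prompt = buildWorksheetParsingPrompt({
//     expectedFormat: 'vertical',
//     expectedProblemCount: 40,
//   })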

/**
 * Build a prompt for re-parsing specific problems with additional context
 */
export function buildReparsePrompt(
  problemNumbers: number[],
  additionalContext: string,
  originalWarnings: string[]
): string {
  return buildWorksheetParsingPrompt({
    focusProblemNumbers: problemNumbers,
    additionalContext: `${additionalContext}

Previous warnings for these problems:
${originalWarnings.map((w) => `- ${w}`).join('\n')}`,
  })
}
287
apps/web/src/lib/worksheet-parsing/schemas.ts
Normal file
@@ -0,0 +1,287 @@
/**
 * Worksheet Parsing Schemas
 *
 * These Zod schemas define the structure of LLM responses when parsing
 * abacus workbook pages. The .describe() annotations are critical -
 * they are automatically extracted and included in the LLM prompt.
 */
import { z } from 'zod'

/**
 * Bounding box in normalized coordinates (0-1)
 * Represents a rectangular region on the worksheet image
 */
export const BoundingBoxSchema = z
  .object({
    x: z
      .number()
      .min(0)
      .max(1)
      .describe('Left edge of the box as a fraction of image width (0 = left edge, 1 = right edge)'),
    y: z
      .number()
      .min(0)
      .max(1)
      .describe('Top edge of the box as a fraction of image height (0 = top edge, 1 = bottom edge)'),
    width: z
      .number()
      .min(0)
      .max(1)
      .describe('Width of the box as a fraction of image width'),
    height: z
      .number()
      .min(0)
      .max(1)
      .describe('Height of the box as a fraction of image height'),
  })
  .describe('Rectangular region on the worksheet image, in normalized 0-1 coordinates')

export type BoundingBox = z.infer<typeof BoundingBoxSchema>

/**
 * Problem format detected in the worksheet
 */
export const ProblemFormatSchema = z
  .enum(['vertical', 'linear'])
  .describe(
    'Format of the problem: "vertical" for stacked column addition/subtraction with answer box below, ' +
      '"linear" for horizontal format like "a + b - c = ___"'
  )

export type ProblemFormat = z.infer<typeof ProblemFormatSchema>

/**
 * Single term in a problem (number with operation)
 */
export const ProblemTermSchema = z
  .number()
  .int()
  .describe(
    'A single term in the problem. Positive numbers represent addition, ' +
      'negative numbers represent subtraction. The first term is always positive. ' +
      'Example: for "45 - 17 + 8", terms are [45, -17, 8]'
  )

/**
 * A single parsed problem from the worksheet
 */
export const ParsedProblemSchema = z
  .object({
    // Identification
    problemNumber: z
      .number()
      .int()
      .min(1)
      .describe('The problem number as printed on the worksheet (1, 2, 3, etc.)'),
    row: z
      .number()
      .int()
      .min(1)
      .describe('Which row of problems this belongs to (1 = top row, 2 = second row, etc.)'),
    column: z
      .number()
      .int()
      .min(1)
      .describe('Which column position in the row (1 = leftmost, counting right)'),

    // Problem content
    format: ProblemFormatSchema,
    terms: z
      .array(ProblemTermSchema)
      .min(2)
      .max(7)
      .describe(
        'All terms in the problem, in order. First term is positive. ' +
          'Subsequent terms are positive for addition, negative for subtraction. ' +
          'Example: "45 - 17 + 8" → [45, -17, 8]'
      ),
    correctAnswer: z
      .number()
      .int()
      .describe('The mathematically correct answer to this problem'),

    // Student work
    studentAnswer: z
      .number()
      .int()
      .nullable()
      .describe(
        'The answer the student wrote, if readable. Null if the answer box is empty, ' +
          'illegible, or you cannot confidently read the student\'s handwriting'
      ),
    studentAnswerConfidence: z
      .number()
      .min(0)
      .max(1)
      .describe(
        'Confidence in reading the student\'s answer (0 = not readable/empty, 1 = perfectly clear). ' +
          'Use 0.5-0.7 for somewhat legible, 0.8-0.9 for mostly clear, 1.0 for crystal clear'
      ),

    // Problem extraction confidence
    termsConfidence: z
      .number()
      .min(0)
      .max(1)
      .describe(
        'Confidence in correctly reading all the problem terms (0 = very unsure, 1 = certain). ' +
          'Lower confidence if digits are smudged, cropped, or partially obscured'
      ),

    // Bounding boxes for UI highlighting
    problemBoundingBox: BoundingBoxSchema.describe(
      'Bounding box around the entire problem (including all terms and answer area)'
    ),
    answerBoundingBox: BoundingBoxSchema.nullable().describe(
      'Bounding box around just the student\'s answer area. Null if no answer area is visible'
    ),
  })
  .describe('A single arithmetic problem extracted from the worksheet')

export type ParsedProblem = z.infer<typeof ParsedProblemSchema>

/**
 * Detected worksheet format
 */
export const WorksheetFormatSchema = z
  .enum(['vertical', 'linear', 'mixed'])
  .describe(
    'Overall format of problems on this page: ' +
      '"vertical" if all problems are stacked column format, ' +
      '"linear" if all are horizontal equation format, ' +
      '"mixed" if the page contains both formats'
  )

/**
 * Page metadata extracted from the worksheet
 */
export const PageMetadataSchema = z
  .object({
    lessonId: z
      .string()
      .nullable()
      .describe(
        'Lesson identifier if printed on the page (e.g., "Lesson 5", "L5", "Unit 2 Lesson 3"). ' +
          'Null if no lesson identifier is visible'
      ),
    weekId: z
      .string()
      .nullable()
      .describe(
        'Week identifier if printed on the page (e.g., "Week 4", "W4"). ' +
          'Null if no week identifier is visible'
      ),
    pageNumber: z
      .number()
      .int()
      .nullable()
      .describe(
        'Page number if printed on the page. Null if no page number is visible'
      ),
    detectedFormat: WorksheetFormatSchema,
    totalRows: z
      .number()
      .int()
      .min(1)
      .max(6)
      .describe('Number of rows of problems on this page (typically 1-4)'),
    problemsPerRow: z
      .number()
      .int()
      .min(1)
      .max(12)
      .describe('Average number of problems per row (typically 8-10)'),
  })
  .describe('Metadata about the worksheet page layout and identifiers')

export type PageMetadata = z.infer<typeof PageMetadataSchema>

/**
 * Complete worksheet parsing result
 */
export const WorksheetParsingResultSchema = z
  .object({
    problems: z
      .array(ParsedProblemSchema)
      .min(1)
      .describe(
        'All problems detected on the worksheet, in reading order (left to right, top to bottom)'
      ),
    pageMetadata: PageMetadataSchema,
    overallConfidence: z
      .number()
      .min(0)
      .max(1)
      .describe(
        'Overall confidence in the parsing accuracy (0 = very uncertain, 1 = highly confident). ' +
          'Based on image quality, problem clarity, and answer legibility'
      ),
    warnings: z
      .array(z.string())
      .describe(
        'List of issues encountered during parsing, such as: ' +
          '"Problem 5 terms partially obscured", ' +
          '"Row 2 problems may be cropped", ' +
          '"Student handwriting difficult to read on problems 3, 7, 12"'
      ),
    needsReview: z
      .boolean()
      .describe(
        'True if any problems have low confidence or warnings that require human review ' +
          'before creating a practice session'
      ),
  })
  .describe('Complete result of parsing an abacus workbook page')

export type WorksheetParsingResult = z.infer<typeof WorksheetParsingResultSchema>
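// A minimal valid WorksheetParsingResult, for orientation (illustrative values):
//
//   {
//     problems: [{
//       problemNumber: 1, row: 1, column: 1,
//       format: 'vertical',
//       terms: [45, -17, 8],
//       correctAnswer: 36,
//       studentAnswer: 36, studentAnswerConfidence: 0.9,
//       termsConfidence: 0.95,
//       problemBoundingBox: { x: 0.05, y: 0.1, width: 0.08, height: 0.3 },
//       answerBoundingBox: { x: 0.05, y: 0.35, width: 0.08, height: 0.05 },
//     }],
//     pageMetadata: {
//       lessonId: null, weekId: null, pageNumber: 12,
//       detectedFormat: 'vertical', totalRows: 4, problemsPerRow: 10,
//     },
//     overallConfidence: 0.92,
//     warnings: [],
//     needsReview: false,
//   }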
|
||||
|
||||
/**
|
||||
* User correction to a parsed problem
|
||||
*/
|
||||
export const ProblemCorrectionSchema = z
|
||||
.object({
|
||||
problemNumber: z
|
||||
.number()
|
||||
.int()
|
||||
.min(1)
|
||||
.describe('The problem number being corrected'),
|
||||
correctedTerms: z
|
||||
.array(ProblemTermSchema)
|
||||
.nullable()
|
||||
.describe('Corrected terms if the LLM got them wrong. Null to keep original'),
|
||||
correctedStudentAnswer: z
|
||||
.number()
|
||||
.int()
|
||||
.nullable()
|
||||
.describe('Corrected student answer. Null means empty/not answered'),
|
||||
shouldExclude: z
|
||||
.boolean()
|
||||
.describe('True to exclude this problem from the session (e.g., illegible)'),
|
||||
note: z
|
||||
.string()
|
||||
.nullable()
|
||||
.describe('Optional note explaining the correction'),
|
||||
})
|
||||
.describe('User correction to a single parsed problem')
|
||||
|
||||
export type ProblemCorrection = z.infer<typeof ProblemCorrectionSchema>
|
||||
|
||||
/**
|
||||
* Request to re-parse with additional context
|
||||
*/
|
||||
export const ReparseRequestSchema = z
|
||||
.object({
|
||||
problemNumbers: z
|
||||
.array(z.number().int().min(1))
|
||||
.describe('Which problems to re-parse'),
|
||||
additionalContext: z
|
||||
.string()
|
||||
.describe(
|
||||
'Additional instructions for the LLM, such as: ' +
|
||||
'"The student writes 7s with a line through them", ' +
|
||||
'"Problem 5 has a 3-digit answer, not 2-digit"'
|
||||
),
|
||||
})
|
||||
.describe('Request to re-parse specific problems with additional context')
|
||||
|
||||
export type ReparseRequest = z.infer<typeof ReparseRequestSchema>
|
||||
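Taken together, these schemas drive the vision call in the parser. A minimal sketch of the parse step, assuming a prompt builder named `buildWorksheetPrompt` and an `imageDataUrl` variable (both hypothetical names here):

```typescript
import { LLMClient } from '@soroban/llm-client'
import { WorksheetParsingResultSchema } from './schemas'

const llm = new LLMClient()

// llm.vision() validates the response against the schema and retries with
// validation feedback in the prompt if the LLM returns a malformed result.
const response = await llm.vision({
  prompt: buildWorksheetPrompt(), // hypothetical prompt builder
  images: [imageDataUrl], // worksheet photo as a base64 data URL
  schema: WorksheetParsingResultSchema,
})

if (response.data.needsReview) {
  // Surface warnings to a human reviewer before creating a session
  console.warn(response.data.warnings)
}
```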
225
apps/web/src/lib/worksheet-parsing/session-converter.ts
Normal file
@@ -0,0 +1,225 @@
/**
 * Session Converter
 *
 * Converts parsed worksheet data into SlotResults that can be
 * used to create an offline practice session.
 */
import type { SlotResult, GeneratedProblem } from '@/db/schema/session-plans'
import type { WorksheetParsingResult, ParsedProblem } from './schemas'
import { analyzeRequiredSkills } from '@/utils/problemGenerator'

/**
 * Options for session conversion
 */
export interface ConversionOptions {
  /** Part number to assign to all problems (default: 1) */
  partNumber?: 1 | 2 | 3
  /** Source identifier for the session results */
  source?: 'practice' | 'recency-refresh'
}

/**
 * Result of session conversion
 */
export interface ConversionResult {
  /** Converted slot results ready for session creation */
  slotResults: Omit<SlotResult, 'timestamp'>[]
  /** Summary statistics */
  summary: {
    totalProblems: number
    answeredProblems: number
    correctAnswers: number
    incorrectAnswers: number
    skippedProblems: number
    accuracy: number | null
  }
  /** Skills that were exercised across all problems */
  skillsExercised: string[]
}

/**
 * Convert a single parsed problem to a GeneratedProblem
 */
function toGeneratedProblem(parsed: ParsedProblem): GeneratedProblem {
  // Calculate correct answer from terms
  const correctAnswer = parsed.terms.reduce((sum, term) => sum + term, 0)

  // Infer skills from terms
  const skillsRequired = analyzeRequiredSkills(parsed.terms, correctAnswer)

  return {
    terms: parsed.terms,
    answer: correctAnswer,
    skillsRequired,
  }
}

/**
 * Convert a parsed problem to a SlotResult
 */
function toSlotResult(
  parsed: ParsedProblem,
  slotIndex: number,
  options: ConversionOptions
): Omit<SlotResult, 'timestamp'> {
  const problem = toGeneratedProblem(parsed)
  const studentAnswer = parsed.studentAnswer ?? 0
  const isCorrect = parsed.studentAnswer !== null && parsed.studentAnswer === problem.answer

  return {
    partNumber: options.partNumber ?? 1,
    slotIndex,
    problem,
    studentAnswer,
    isCorrect,
    responseTimeMs: 0, // Unknown for offline work
    skillsExercised: problem.skillsRequired,
    usedOnScreenAbacus: false,
    hadHelp: false,
    incorrectAttempts: isCorrect ? 0 : parsed.studentAnswer !== null ? 1 : 0,
    source: options.source,
  }
}

/**
 * Convert parsed worksheet results to SlotResults
 *
 * Filters out problems that were marked for exclusion and converts
 * the remaining problems into the format needed for session creation.
 *
 * @param parsingResult - The parsed worksheet data
 * @param options - Conversion options
 * @returns Conversion result with slot results and summary
 *
 * @example
 * ```typescript
 * import { convertToSlotResults } from '@/lib/worksheet-parsing'
 *
 * const result = convertToSlotResults(parsingResult, { partNumber: 1 })
 *
 * // Create session with results
 * await createSession({
 *   playerId,
 *   status: 'completed',
 *   slotResults: result.slotResults,
 * })
 * ```
 */
export function convertToSlotResults(
  parsingResult: WorksheetParsingResult,
  options: ConversionOptions = {}
): ConversionResult {
  const problems = parsingResult.problems
  const slotResults: Omit<SlotResult, 'timestamp'>[] = []
  const allSkills = new Set<string>()

  let answeredCount = 0
  let correctCount = 0

  for (let i = 0; i < problems.length; i++) {
    const parsed = problems[i]
    const slotResult = toSlotResult(parsed, i, options)
    slotResults.push(slotResult)

    // Track skills
    for (const skill of slotResult.skillsExercised) {
      allSkills.add(skill)
    }

    // Track statistics
    if (parsed.studentAnswer !== null) {
      answeredCount++
      if (slotResult.isCorrect) {
        correctCount++
      }
    }
  }

  const skippedCount = problems.length - answeredCount

  return {
    slotResults,
    summary: {
      totalProblems: problems.length,
      answeredProblems: answeredCount,
      correctAnswers: correctCount,
      incorrectAnswers: answeredCount - correctCount,
      skippedProblems: skippedCount,
      accuracy: answeredCount > 0 ? correctCount / answeredCount : null,
    },
    skillsExercised: Array.from(allSkills),
  }
}

/**
 * Validate that parsed problems have reasonable values
 *
 * Returns warnings for any issues found.
 */
export function validateParsedProblems(
  problems: ParsedProblem[]
): { valid: boolean; warnings: string[] } {
  const warnings: string[] = []

  for (const problem of problems) {
    // Check that correct answer matches term sum
    const expectedAnswer = problem.terms.reduce((sum, t) => sum + t, 0)
    if (problem.correctAnswer !== expectedAnswer) {
      warnings.push(
        `Problem ${problem.problemNumber}: correctAnswer (${problem.correctAnswer}) ` +
          `doesn't match sum of terms (${expectedAnswer})`
      )
    }

    // Check for negative answers (valid but unusual)
    if (expectedAnswer < 0) {
      warnings.push(
        `Problem ${problem.problemNumber}: negative answer (${expectedAnswer}) - verify this is correct`
      )
    }

    // Check for very large numbers (may indicate misread)
    if (Math.abs(expectedAnswer) > 9999) {
      warnings.push(
        `Problem ${problem.problemNumber}: very large answer (${expectedAnswer}) - verify reading`
      )
    }

    // Check for low confidence
    if (problem.termsConfidence < 0.5) {
      warnings.push(
        `Problem ${problem.problemNumber}: very low term confidence (${problem.termsConfidence.toFixed(2)})`
      )
    }
  }

  return {
    valid: warnings.length === 0,
    warnings,
  }
}

/**
 * Compute aggregate skill statistics from slot results
 */
export function computeSkillStats(
  slotResults: Omit<SlotResult, 'timestamp'>[]
): Map<string, { correct: number; incorrect: number; total: number }> {
  const skillStats = new Map<string, { correct: number; incorrect: number; total: number }>()

  for (const result of slotResults) {
    for (const skill of result.skillsExercised) {
      const stats = skillStats.get(skill) ?? { correct: 0, incorrect: 0, total: 0 }
      stats.total++
      if (result.isCorrect) {
        stats.correct++
      } else if (result.studentAnswer !== 0) {
        // Only count as incorrect if student answered
        stats.incorrect++
      }
      skillStats.set(skill, stats)
    }
  }

  return skillStats
}
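A short sketch tying the validator, converter, and stats helper together; the `updateBkt` call at the end is a hypothetical stand-in for the BKT update:

```typescript
import {
  convertToSlotResults,
  computeSkillStats,
  validateParsedProblems,
} from './session-converter'

// Given a parsingResult produced by the worksheet parser:
const { warnings } = validateParsedProblems(parsingResult.problems)
const conversion = convertToSlotResults(parsingResult, { partNumber: 1 })
const stats = computeSkillStats(conversion.slotResults)

for (const [skill, { correct, total }] of stats) {
  updateBkt(skill, correct, total) // hypothetical knowledge-tracing hook
}
```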
257
packages/llm-client/README.md
Normal file
@@ -0,0 +1,257 @@
# @soroban/llm-client

Type-safe LLM client with multi-provider support, Zod schema validation, and retry logic with validation feedback.

## Features

- **Multi-provider support**: OpenAI, Anthropic (more coming)
- **Type-safe responses**: Zod schema validation with full TypeScript inference
- **Schema-driven prompts**: Zod `.describe()` annotations are automatically included in prompts
- **Retry with feedback**: Failed validations are fed back to the LLM for correction
- **Vision support**: Pass images for multimodal requests
- **Progress callbacks**: Track LLM call progress for UI feedback
- **Environment-based config**: Configure providers via env vars

## Installation

```bash
pnpm add @soroban/llm-client zod
```

## Configuration

Set environment variables for your providers:

```bash
# Default provider
LLM_DEFAULT_PROVIDER=openai
LLM_DEFAULT_MODEL=gpt-4o

# OpenAI
LLM_OPENAI_API_KEY=sk-...
LLM_OPENAI_BASE_URL=https://api.openai.com/v1 # optional

# Anthropic
LLM_ANTHROPIC_API_KEY=sk-ant-...
LLM_ANTHROPIC_BASE_URL=https://api.anthropic.com/v1 # optional
```
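Configuration can also be passed programmatically. The constructor merges overrides on top of whatever it reads from the environment, and the optional second argument replaces `process.env` (useful in tests):

```typescript
import { LLMClient } from '@soroban/llm-client'

// Overrides win over env-derived settings
const llm = new LLMClient(
  { defaultProvider: 'anthropic', defaultMaxRetries: 3 },
  { LLM_ANTHROPIC_API_KEY: 'sk-ant-...' }
)
```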
## Usage

### Basic Usage

```typescript
import { LLMClient } from '@soroban/llm-client'
import { z } from 'zod'

const llm = new LLMClient()

// Define your response schema with descriptions
// IMPORTANT: Use .describe() on every field - these are sent to the LLM!
const SentimentSchema = z.object({
  sentiment: z.enum(['positive', 'negative', 'neutral'])
    .describe('The overall sentiment detected in the text'),
  confidence: z.number().min(0).max(1)
    .describe('How confident the analysis is, from 0 (uncertain) to 1 (certain)'),
  reasoning: z.string()
    .describe('Brief explanation of why this sentiment was detected'),
}).describe('Sentiment analysis result')

// Make a type-safe call
const response = await llm.call({
  prompt: 'Analyze the sentiment of: "I love this product!"',
  schema: SentimentSchema,
})

// response.data is fully typed
console.log(response.data.sentiment) // 'positive'
console.log(response.data.confidence) // 0.95
```

### Schema Descriptions (Critical!)

**The `.describe()` method is how you communicate expectations to the LLM.** Every field description you add is automatically extracted and included in the prompt sent to the LLM.

```typescript
// ❌ Bad: No context for the LLM
const BadSchema = z.object({
  value: z.number(),
  items: z.array(z.string()),
})

// ✅ Good: Rich context guides LLM responses
const GoodSchema = z.object({
  value: z.number()
    .describe('The total price in USD, with up to 2 decimal places'),
  items: z.array(
    z.string().describe('Product name exactly as shown on receipt')
  ).describe('All line items from the receipt'),
}).describe('Parsed receipt data')
```

When you call `llm.call()`, the prompt sent to the LLM includes:

```
[Your prompt here]

## Response Format

Respond with JSON matching the following structure:

### Field Descriptions
- **Response**: Parsed receipt data
- **value**: The total price in USD, with up to 2 decimal places
- **items**: All line items from the receipt
- **items[]**: Product name exactly as shown on receipt

### JSON Schema
[Full JSON schema for validation]
```

This ensures the LLM understands:
1. What each field represents semantically
2. What format/constraints to follow
3. How nested structures should be filled

### Vision Requests

```typescript
const ImageAnalysisSchema = z.object({
  description: z.string()
    .describe('A detailed description of the main subject'),
  objects: z.array(z.string().describe('Name of an object visible in the image'))
    .describe('All distinct objects identified in the image'),
}).describe('Image analysis result')

const response = await llm.vision({
  prompt: 'Describe what you see in this image',
  images: ['data:image/jpeg;base64,...'],
  schema: ImageAnalysisSchema,
})
```

### Progress Tracking

```typescript
const response = await llm.call({
  prompt: 'Complex analysis...',
  schema: MySchema,
  onProgress: (progress) => {
    console.log(`${progress.stage}: ${progress.message}`)
    // 'calling: Calling LLM...'
    // 'validating: Validating response...'
    // 'retrying: Retry 1/2: fixing sentiment'
  },
})
```

### Provider Selection

```typescript
// Use a specific provider
const response = await llm.call({
  prompt: 'Hello!',
  schema: ResponseSchema,
  provider: 'anthropic',
  model: 'claude-sonnet-4-20250514',
})

// Check available providers
console.log(llm.getProviders()) // ['openai', 'anthropic']
console.log(llm.isProviderAvailable('openai')) // true
```

### Retry Configuration

```typescript
const response = await llm.call({
  prompt: 'Extract data...',
  schema: StrictSchema,
  maxRetries: 3, // Default is 2
})

// If validation fails, the LLM receives feedback like:
// "PREVIOUS ATTEMPT HAD VALIDATION ERROR:
//  Field: items.0.price
//  Error: Expected number, received string
//  Please correct this error and provide a valid response."
```

## API Reference

### `LLMClient`

Main client class for making LLM calls.

#### Constructor

```typescript
new LLMClient(configOverrides?: Partial<LLMClientConfig>, env?: Record<string, string | undefined>)
```

#### Methods

- `call<T>(request: LLMRequest<T>): Promise<LLMResponse<T>>` - Make a structured LLM call
- `vision<T>(request: LLMRequest<T> & { images: string[] }): Promise<LLMResponse<T>>` - Vision call
- `getProviders(): string[]` - List configured providers
- `isProviderAvailable(name: string): boolean` - Check if provider is configured
- `getDefaultProvider(): string` - Get default provider name
- `getDefaultModel(provider?: string): string` - Get default model

### Types

```typescript
interface LLMRequest<T extends z.ZodType> {
  prompt: string
  images?: string[]
  schema: T
  provider?: string
  model?: string
  maxRetries?: number
  onProgress?: (progress: LLMProgress) => void
}

interface LLMResponse<T> {
  data: T
  usage: { promptTokens: number; completionTokens: number; totalTokens: number }
  attempts: number
  provider: string
  model: string
}

interface LLMProgress {
  stage: 'preparing' | 'calling' | 'validating' | 'retrying'
  attempt: number
  maxAttempts: number
  message: string
  validationError?: ValidationFeedback
}
```

## Adding Custom Providers

You can extend the `BaseProvider` class to add support for additional LLM providers:

```typescript
import { BaseProvider, ProviderConfig, ProviderRequest, ProviderResponse } from '@soroban/llm-client'

class MyProvider extends BaseProvider {
  constructor(config: ProviderConfig) {
    super(config)
  }

  async call(request: ProviderRequest): Promise<ProviderResponse> {
    const prompt = this.buildPrompt(request) // Includes validation feedback
    // ... make API call
    return {
      content: parsedResponse,
      usage: { promptTokens: 100, completionTokens: 50 },
      finishReason: 'stop',
    }
  }
}
```
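Note that provider *discovery* and provider *construction* are separate steps: `loadConfigFromEnv` picks up any `LLM_{NAME}_API_KEY` variable, but the client's internal factory registry currently covers only `openai` and `anthropic`, so passing another provider name to `llm.call()` throws `Unknown provider`. In this version a custom `BaseProvider` subclass has to be invoked directly rather than through `LLMClient`.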
## License

MIT
60
packages/llm-client/package.json
Normal file
@@ -0,0 +1,60 @@
{
  "name": "@soroban/llm-client",
  "version": "1.0.0",
  "description": "Type-safe LLM client with multi-provider support, Zod schema validation, and retry logic",
  "main": "./dist/index.js",
  "module": "./dist/index.js",
  "types": "./dist/index.d.ts",
  "exports": {
    ".": {
      "types": "./dist/index.d.ts",
      "import": "./dist/index.js",
      "require": "./dist/index.js"
    }
  },
  "files": [
    "dist/**/*",
    "src/**/*",
    "README.md"
  ],
  "scripts": {
    "build": "tsc",
    "dev": "tsc --watch",
    "clean": "rm -rf dist",
    "type-check": "tsc --noEmit",
    "test": "vitest",
    "test:run": "vitest run"
  },
  "keywords": [
    "llm",
    "openai",
    "anthropic",
    "claude",
    "gpt",
    "ai",
    "typescript",
    "zod",
    "schema",
    "validation"
  ],
  "dependencies": {},
  "peerDependencies": {
    "zod": "^4.0.0"
  },
  "devDependencies": {
    "@types/node": "^20.0.0",
    "typescript": "^5.0.0",
    "vitest": "^1.0.0",
    "zod": "^4.1.12"
  },
  "author": "Soroban Flashcards Team",
  "license": "MIT",
  "repository": {
    "type": "git",
    "url": "https://github.com/antialias/soroban-abacus-flashcards",
    "directory": "packages/llm-client"
  },
  "engines": {
    "node": ">=18.0.0"
  }
}
266
packages/llm-client/src/client.ts
Normal file
@@ -0,0 +1,266 @@
import { z } from 'zod'
import type {
  LLMClientConfig,
  LLMProvider,
  LLMRequest,
  LLMResponse,
  ProviderRequest,
  ValidationFeedback,
} from './types'
import { ProviderNotConfiguredError } from './types'
import {
  loadConfigFromEnv,
  getProviderConfig,
  getConfiguredProviders,
  isProviderConfigured,
} from './config'
import { executeWithRetry } from './retry'
import { OpenAIProvider } from './providers/openai'
import { AnthropicProvider } from './providers/anthropic'

/**
 * Factory function type for creating providers
 */
type ProviderFactory = (config: LLMClientConfig, providerName: string) => LLMProvider

/**
 * Registry of provider factories
 */
const providerFactories: Record<string, ProviderFactory> = {
  openai: (config, name) => {
    const providerConfig = getProviderConfig(config, name)
    if (!providerConfig) throw new ProviderNotConfiguredError(name)
    return new OpenAIProvider(providerConfig)
  },
  anthropic: (config, name) => {
    const providerConfig = getProviderConfig(config, name)
    if (!providerConfig) throw new ProviderNotConfiguredError(name)
    return new AnthropicProvider(providerConfig)
  },
}

/**
 * LLM Client for making type-safe LLM calls with multi-provider support
 *
 * Features:
 * - Multi-provider support (OpenAI, Anthropic, etc.)
 * - Zod schema validation for responses
 * - Zod .describe() annotations included in prompts for LLM context
 * - Retry logic with validation feedback
 * - Progress callbacks for UI updates
 *
 * @example
 * ```typescript
 * import { LLMClient } from '@soroban/llm-client'
 * import { z } from 'zod'
 *
 * const llm = new LLMClient()
 *
 * const response = await llm.call({
 *   prompt: 'Extract sentiment from: "I love this!"',
 *   schema: z.object({
 *     sentiment: z.enum(['positive', 'negative', 'neutral'])
 *       .describe('The detected sentiment'),
 *     confidence: z.number()
 *       .describe('Confidence score between 0 and 1'),
 *   }).describe('Sentiment analysis result'),
 * })
 *
 * console.log(response.data.sentiment) // TypeScript knows this is valid
 * ```
 */
export class LLMClient {
  private readonly config: LLMClientConfig
  private readonly providers: Map<string, LLMProvider> = new Map()

  /**
   * Create a new LLM client
   *
   * @param configOverrides - Optional configuration overrides
   * @param env - Environment variables (defaults to process.env)
   */
  constructor(
    configOverrides?: Partial<LLMClientConfig>,
    env?: Record<string, string | undefined>
  ) {
    const envConfig = loadConfigFromEnv(env)
    this.config = {
      ...envConfig,
      ...configOverrides,
      providers: {
        ...envConfig.providers,
        ...configOverrides?.providers,
      },
    }
  }

  /**
   * Make a structured LLM call with schema validation
   *
   * @param request - The request configuration
   * @returns Type-safe response with validated data
   */
  async call<T extends z.ZodType>(
    request: LLMRequest<T>
  ): Promise<LLMResponse<z.infer<T>>> {
    return this.executeRequest(request)
  }

  /**
   * Make a vision call (convenience method for requests with images)
   *
   * @param request - The request configuration (must include images)
   * @returns Type-safe response with validated data
   */
  async vision<T extends z.ZodType>(
    request: LLMRequest<T> & { images: string[] }
  ): Promise<LLMResponse<z.infer<T>>> {
    return this.executeRequest(request)
  }

  /**
   * Get list of configured providers
   */
  getProviders(): string[] {
    return getConfiguredProviders(this.config)
  }

  /**
   * Check if a provider is configured
   */
  isProviderAvailable(providerName: string): boolean {
    return isProviderConfigured(this.config, providerName)
  }

  /**
   * Get the default provider name
   */
  getDefaultProvider(): string {
    return this.config.defaultProvider
  }

  /**
   * Get the default model
   */
  getDefaultModel(providerName?: string): string {
    if (this.config.defaultModel) {
      return this.config.defaultModel
    }
    const provider = getProviderConfig(this.config, providerName)
    return provider?.defaultModel ?? 'default'
  }

  /**
   * Execute the LLM request with retry logic
   */
  private async executeRequest<T extends z.ZodType>(
    request: LLMRequest<T>
  ): Promise<LLMResponse<z.infer<T>>> {
    const providerName = request.provider ?? this.config.defaultProvider
    const model = request.model ?? this.getDefaultModel(providerName)
    const maxRetries = request.maxRetries ?? this.config.defaultMaxRetries

    // Get or create provider instance
    const provider = this.getOrCreateProvider(providerName)

    // Convert Zod schema to JSON Schema using Zod v4's native method
    // This preserves .describe() annotations as "description" fields
    const jsonSchema = z.toJSONSchema(request.schema, {
      unrepresentable: 'any', // Convert unrepresentable types to {} instead of throwing
    }) as Record<string, unknown>

    // Execute with retry logic
    const { result: providerResponse, attempts } = await executeWithRetry(
      async (validationFeedback?: ValidationFeedback) => {
        const providerRequest: ProviderRequest = {
          prompt: request.prompt,
          images: request.images,
          jsonSchema,
          model,
          validationFeedback,
        }

        return provider.call(providerRequest)
      },
      (response) => {
        // Validate response against schema
        const parseResult = request.schema.safeParse(response.content)

        if (!parseResult.success) {
          // Extract first error for feedback (Zod v4 uses 'issues')
          const firstIssue = parseResult.error.issues[0]
          if (firstIssue) {
            return {
              field: firstIssue.path.join('.') || 'root',
              error: firstIssue.message,
              received: response.content,
            }
          }
          return {
            field: 'root',
            error: 'Validation failed',
            received: response.content,
          }
        }

        return null // Valid
      },
      {
        maxRetries,
        onProgress: request.onProgress,
      }
    )

    // Parse the validated response
    const parseResult = request.schema.safeParse(providerResponse.content)
    if (!parseResult.success) {
      // Should not happen after retry validation, but handle gracefully
      throw new Error('Validation failed after retry')
    }

    return {
      data: parseResult.data,
      usage: {
        promptTokens: providerResponse.usage.promptTokens,
        completionTokens: providerResponse.usage.completionTokens,
        totalTokens:
          providerResponse.usage.promptTokens + providerResponse.usage.completionTokens,
      },
      attempts,
      provider: providerName,
      model,
    }
  }

  /**
   * Get or create a provider instance
   */
  private getOrCreateProvider(providerName: string): LLMProvider {
    const name = providerName.toLowerCase()

    // Check cache
    const cached = this.providers.get(name)
    if (cached) {
      return cached
    }

    // Check if provider is configured
    if (!isProviderConfigured(this.config, name)) {
      throw new ProviderNotConfiguredError(name)
    }

    // Get factory
    const factory = providerFactories[name]
    if (!factory) {
      throw new Error(
        `Unknown provider: ${name}. Supported providers: ${Object.keys(providerFactories).join(', ')}`
      )
    }

    // Create and cache provider
    const provider = factory(this.config, name)
    this.providers.set(name, provider)

    return provider
  }
}
203
packages/llm-client/src/config.test.ts
Normal file
@@ -0,0 +1,203 @@
import { describe, it, expect } from 'vitest'
import {
  loadConfigFromEnv,
  getProviderConfig,
  getConfiguredProviders,
  isProviderConfigured,
} from './config'

describe('config', () => {
  describe('loadConfigFromEnv', () => {
    it('should load default configuration when no env vars set', () => {
      const config = loadConfigFromEnv({})

      expect(config.defaultProvider).toBe('openai')
      expect(config.defaultMaxRetries).toBe(2)
      expect(Object.keys(config.providers)).toHaveLength(0)
    })

    it('should load default provider from env', () => {
      const config = loadConfigFromEnv({
        LLM_DEFAULT_PROVIDER: 'anthropic',
      })

      expect(config.defaultProvider).toBe('anthropic')
    })

    it('should load default model from env', () => {
      const config = loadConfigFromEnv({
        LLM_DEFAULT_MODEL: 'gpt-5',
      })

      expect(config.defaultModel).toBe('gpt-5')
    })

    it('should load max retries from env', () => {
      const config = loadConfigFromEnv({
        LLM_DEFAULT_MAX_RETRIES: '5',
      })

      expect(config.defaultMaxRetries).toBe(5)
    })

    it('should load OpenAI provider configuration', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test-key',
        LLM_OPENAI_BASE_URL: 'https://custom.openai.com',
        LLM_OPENAI_DEFAULT_MODEL: 'gpt-4-turbo',
      })

      expect(config.providers.openai).toBeDefined()
      expect(config.providers.openai.apiKey).toBe('sk-test-key')
      expect(config.providers.openai.baseUrl).toBe('https://custom.openai.com')
      expect(config.providers.openai.defaultModel).toBe('gpt-4-turbo')
    })

    it('should use default base URL for OpenAI if not provided', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test-key',
      })

      expect(config.providers.openai.baseUrl).toBe('https://api.openai.com/v1')
    })

    it('should use default model for OpenAI if not provided', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test-key',
      })

      expect(config.providers.openai.defaultModel).toBe('gpt-4o')
    })

    it('should load Anthropic provider configuration', () => {
      const config = loadConfigFromEnv({
        LLM_ANTHROPIC_API_KEY: 'sk-ant-test',
      })

      expect(config.providers.anthropic).toBeDefined()
      expect(config.providers.anthropic.apiKey).toBe('sk-ant-test')
      expect(config.providers.anthropic.baseUrl).toBe('https://api.anthropic.com/v1')
      expect(config.providers.anthropic.defaultModel).toBe('claude-sonnet-4-20250514')
    })

    it('should load multiple providers', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-openai',
        LLM_ANTHROPIC_API_KEY: 'sk-anthropic',
      })

      expect(Object.keys(config.providers)).toHaveLength(2)
      expect(config.providers.openai).toBeDefined()
      expect(config.providers.anthropic).toBeDefined()
    })

    it('should discover custom providers from API key pattern', () => {
      const config = loadConfigFromEnv({
        LLM_CUSTOM_API_KEY: 'sk-custom',
        LLM_CUSTOM_BASE_URL: 'https://api.custom.com',
        LLM_CUSTOM_DEFAULT_MODEL: 'custom-model',
      })

      expect(config.providers.custom).toBeDefined()
      expect(config.providers.custom.apiKey).toBe('sk-custom')
      expect(config.providers.custom.baseUrl).toBe('https://api.custom.com')
      expect(config.providers.custom.defaultModel).toBe('custom-model')
    })

    it('should not create provider config without API key', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_BASE_URL: 'https://custom.com',
      })

      expect(config.providers.openai).toBeUndefined()
    })
  })

  describe('getProviderConfig', () => {
    it('should return provider config by name', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test',
      })

      const provider = getProviderConfig(config, 'openai')

      expect(provider).toBeDefined()
      expect(provider?.apiKey).toBe('sk-test')
    })

    it('should return default provider config when name not specified', () => {
      const config = loadConfigFromEnv({
        LLM_DEFAULT_PROVIDER: 'anthropic',
        LLM_ANTHROPIC_API_KEY: 'sk-test',
      })

      const provider = getProviderConfig(config)

      expect(provider).toBeDefined()
      expect(provider?.name).toBe('anthropic')
    })

    it('should be case insensitive', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test',
      })

      const provider = getProviderConfig(config, 'OpenAI')

      expect(provider).toBeDefined()
    })

    it('should return undefined for non-existent provider', () => {
      const config = loadConfigFromEnv({})

      const provider = getProviderConfig(config, 'nonexistent')

      expect(provider).toBeUndefined()
    })
  })

  describe('getConfiguredProviders', () => {
    it('should return empty array when no providers configured', () => {
      const config = loadConfigFromEnv({})

      expect(getConfiguredProviders(config)).toEqual([])
    })

    it('should return list of configured provider names', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-openai',
        LLM_ANTHROPIC_API_KEY: 'sk-anthropic',
      })

      const providers = getConfiguredProviders(config)

      expect(providers).toContain('openai')
      expect(providers).toContain('anthropic')
      expect(providers).toHaveLength(2)
    })
  })

  describe('isProviderConfigured', () => {
    it('should return true for configured provider', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test',
      })

      expect(isProviderConfigured(config, 'openai')).toBe(true)
    })

    it('should return false for non-configured provider', () => {
      const config = loadConfigFromEnv({})

      expect(isProviderConfigured(config, 'openai')).toBe(false)
    })

    it('should be case insensitive', () => {
      const config = loadConfigFromEnv({
        LLM_OPENAI_API_KEY: 'sk-test',
      })

      expect(isProviderConfigured(config, 'OPENAI')).toBe(true)
    })
  })
})
144
packages/llm-client/src/config.ts
Normal file
@@ -0,0 +1,144 @@
import type { LLMClientConfig, ProviderConfig } from './types'

/**
 * Known provider defaults
 */
const PROVIDER_DEFAULTS: Record<string, Partial<ProviderConfig>> = {
  openai: {
    baseUrl: 'https://api.openai.com/v1',
    defaultModel: 'gpt-4o',
  },
  anthropic: {
    baseUrl: 'https://api.anthropic.com/v1',
    defaultModel: 'claude-sonnet-4-20250514',
  },
}

/**
 * Parse provider configuration from environment variables
 *
 * Env var convention:
 * - LLM_{PROVIDER}_API_KEY - API key (required)
 * - LLM_{PROVIDER}_BASE_URL - Base URL (optional, has defaults)
 * - LLM_{PROVIDER}_DEFAULT_MODEL - Default model (optional)
 * - LLM_{PROVIDER}_{OPTION} - Additional options
 */
function parseProviderFromEnv(
  providerName: string,
  env: Record<string, string | undefined>
): ProviderConfig | null {
  const prefix = `LLM_${providerName.toUpperCase()}_`

  const apiKey = env[`${prefix}API_KEY`]
  if (!apiKey) {
    return null
  }

  const defaults = PROVIDER_DEFAULTS[providerName.toLowerCase()] ?? {}

  const baseUrl =
    env[`${prefix}BASE_URL`] ?? defaults.baseUrl ?? `https://api.${providerName}.com/v1`

  const defaultModel =
    env[`${prefix}DEFAULT_MODEL`] ?? defaults.defaultModel ?? 'default'

  // Collect any additional options
  const options: Record<string, unknown> = {}
  for (const [key, value] of Object.entries(env)) {
    if (
      key.startsWith(prefix) &&
      !['API_KEY', 'BASE_URL', 'DEFAULT_MODEL'].includes(key.slice(prefix.length))
    ) {
      const optionName = key.slice(prefix.length).toLowerCase()
      options[optionName] = value
    }
  }

  return {
    name: providerName.toLowerCase(),
    apiKey,
    baseUrl,
    defaultModel,
    options: Object.keys(options).length > 0 ? options : undefined,
  }
}

/**
 * Load LLM client configuration from environment variables
 *
 * Env vars:
 * - LLM_DEFAULT_PROVIDER - Default provider to use (optional, defaults to 'openai')
 * - LLM_DEFAULT_MODEL - Override default model (optional)
 * - LLM_DEFAULT_MAX_RETRIES - Default max retries (optional, default: 2)
 * - LLM_{PROVIDER}_* - Provider-specific configuration
 *
 * @param env - Environment variables (defaults to process.env)
 */
export function loadConfigFromEnv(
  env: Record<string, string | undefined> = process.env
): LLMClientConfig {
  const defaultProvider = env.LLM_DEFAULT_PROVIDER?.toLowerCase() ?? 'openai'
  const defaultModel = env.LLM_DEFAULT_MODEL
  const defaultMaxRetries = parseInt(env.LLM_DEFAULT_MAX_RETRIES ?? '2', 10)

  // Discover configured providers by scanning for API keys
  const providers: Record<string, ProviderConfig> = {}

  // Check known providers
  const knownProviders = ['openai', 'anthropic']
  for (const provider of knownProviders) {
    const config = parseProviderFromEnv(provider, env)
    if (config) {
      providers[provider] = config
    }
  }

  // Also check for any LLM_*_API_KEY pattern to discover custom providers
  for (const key of Object.keys(env)) {
    const match = key.match(/^LLM_([A-Z0-9_]+)_API_KEY$/)
    if (match) {
      const providerName = match[1].toLowerCase()
      if (!providers[providerName]) {
        const config = parseProviderFromEnv(providerName, env)
        if (config) {
          providers[providerName] = config
        }
      }
    }
  }

  return {
    defaultProvider,
    defaultModel,
    providers,
    defaultMaxRetries,
  }
}

/**
 * Get a specific provider config
 */
export function getProviderConfig(
  config: LLMClientConfig,
  providerName?: string
): ProviderConfig | undefined {
  const name = providerName?.toLowerCase() ?? config.defaultProvider
  return config.providers[name]
}

/**
 * Check if a provider is configured
 */
export function isProviderConfigured(
  config: LLMClientConfig,
  providerName: string
): boolean {
  return providerName.toLowerCase() in config.providers
}

/**
 * Get list of configured provider names
 */
export function getConfiguredProviders(config: LLMClientConfig): string[] {
  return Object.keys(config.providers)
}
72
packages/llm-client/src/index.ts
Normal file
@@ -0,0 +1,72 @@
/**
 * @soroban/llm-client
 *
 * Type-safe LLM client with multi-provider support, Zod schema validation,
 * and retry logic with validation feedback.
 *
 * @example
 * ```typescript
 * import { LLMClient } from '@soroban/llm-client'
 * import { z } from 'zod'
 *
 * const llm = new LLMClient()
 *
 * const SentimentSchema = z.object({
 *   sentiment: z.enum(['positive', 'negative', 'neutral']),
 *   confidence: z.number().min(0).max(1),
 * })
 *
 * const response = await llm.call({
 *   prompt: 'Analyze sentiment: "I love this product!"',
 *   schema: SentimentSchema,
 *   onProgress: (p) => console.log(p.message),
 * })
 *
 * console.log(response.data.sentiment) // 'positive'
 * ```
 *
 * @packageDocumentation
 */

// Main client
export { LLMClient } from './client'

// Types
export type {
  LLMClientConfig,
  LLMRequest,
  LLMResponse,
  LLMProgress,
  LLMProvider,
  ProviderConfig,
  ProviderRequest,
  ProviderResponse,
  ValidationFeedback,
} from './types'

// Errors
export {
  LLMValidationError,
  LLMApiError,
  LLMTruncationError,
  LLMContentFilterError,
  LLMJsonParseError,
  ProviderNotConfiguredError,
} from './types'

// Config utilities
export {
  loadConfigFromEnv,
  getProviderConfig,
  getConfiguredProviders,
  isProviderConfigured,
} from './config'

// Retry utilities (for advanced usage)
export { executeWithRetry, buildFeedbackPrompt, isRetryableError, getRetryDelay } from './retry'
export type { RetryOptions } from './retry'

// Providers (for advanced usage / custom providers)
export { BaseProvider } from './providers/base'
export { OpenAIProvider } from './providers/openai'
export { AnthropicProvider } from './providers/anthropic'
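The exported error classes let callers distinguish failure modes; a minimal sketch (which error surfaces when retries are exhausted is an assumption here):

```typescript
import { z } from 'zod'
import {
  LLMClient,
  LLMApiError,
  LLMValidationError,
  ProviderNotConfiguredError,
} from '@soroban/llm-client'

const llm = new LLMClient()
const schema = z.object({ answer: z.string().describe('The answer') })

try {
  const response = await llm.call({ prompt: 'Say hello', schema })
  console.log(response.data.answer)
} catch (err) {
  if (err instanceof ProviderNotConfiguredError) {
    // Missing LLM_{PROVIDER}_API_KEY in the environment
  } else if (err instanceof LLMApiError) {
    // Upstream HTTP failure (carries status and optional Retry-After)
  } else if (err instanceof LLMValidationError) {
    // Response never matched the schema (assumed thrown once retries run out)
  }
  throw err
}
```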
218
packages/llm-client/src/providers/anthropic.ts
Normal file
@@ -0,0 +1,218 @@
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
import { LLMApiError, LLMTruncationError, LLMContentFilterError } from '../types'
import { BaseProvider } from './base'

/**
 * Anthropic content block
 */
interface ContentBlock {
  type: 'text' | 'image' | 'tool_use' | 'tool_result'
  text?: string
  source?: {
    type: 'base64'
    media_type: string
    data: string
  }
  id?: string
  name?: string
  input?: unknown
}

/**
 * Anthropic message
 */
interface AnthropicMessage {
  role: 'user' | 'assistant'
  content: string | ContentBlock[]
}

/**
 * Anthropic messages response
 */
interface MessagesResponse {
  id: string
  type: 'message' | 'error'
  role: 'assistant'
  content: ContentBlock[]
  stop_reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | null
  usage: {
    input_tokens: number
    output_tokens: number
  }
}

/**
 * Anthropic error response structure
 */
interface AnthropicErrorResponse {
  type?: 'error'
  error?: {
    type?: string
    message?: string
  }
}

/**
 * Anthropic provider implementation
 *
 * Uses the Messages API with tool use for structured output.
 * Falls back to JSON parsing from text if tool use is not available.
 */
export class AnthropicProvider extends BaseProvider {
  constructor(config: ProviderConfig) {
    super(config)
  }

  async call(request: ProviderRequest): Promise<ProviderResponse> {
    const prompt = this.buildPrompt(request)
    const messages = this.buildMessages(prompt, request.images)

    // Use tool use for structured output
    const tool = {
      name: 'provide_response',
      description:
        'Provide the response in the required JSON format. Always use this tool.',
      input_schema: request.jsonSchema,
    }

    const requestBody: Record<string, unknown> = {
      model: request.model,
      max_tokens: 4096,
      messages,
      tools: [tool],
      tool_choice: { type: 'tool', name: 'provide_response' },
    }

    const response = await fetch(`${this.config.baseUrl}/messages`, {
      method: 'POST',
      headers: {
        'x-api-key': this.config.apiKey,
        'anthropic-version': '2023-06-01',
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(requestBody),
    })

    if (!response.ok) {
      const errorText = await response.text()
      let errorMessage = errorText
      let errorType: string | undefined

      try {
        const errorJson = JSON.parse(errorText) as AnthropicErrorResponse
        errorMessage = errorJson.error?.message ?? errorText
        errorType = errorJson.error?.type
      } catch {
        // Keep original text
      }

      // Parse Retry-After header for rate limits
      const retryAfterMs = this.parseRetryAfter(response.headers)

      // Check for specific Anthropic error types
      if (errorType === 'invalid_request_error' && errorMessage.includes('content filtering')) {
        throw new LLMContentFilterError(this.name, errorMessage)
      }

      throw new LLMApiError(this.name, response.status, errorMessage, retryAfterMs)
    }

    const data = (await response.json()) as MessagesResponse

    // Check for max_tokens (truncation)
    if (data.stop_reason === 'max_tokens') {
      // Try to extract partial content
      const toolUseBlock = data.content.find((block) => block.type === 'tool_use')
      const textBlock = data.content.find((block) => block.type === 'text')
      const partialContent = toolUseBlock?.input ?? textBlock?.text ?? null

      throw new LLMTruncationError(this.name, partialContent)
    }

    // Find the tool use block
    const toolUseBlock = data.content.find((block) => block.type === 'tool_use')

    if (!toolUseBlock || toolUseBlock.type !== 'tool_use') {
      // Fall back to text content
      const textBlock = data.content.find((block) => block.type === 'text')
      if (textBlock && textBlock.text) {
        // Check if it's a refusal
        const lowerText = textBlock.text.toLowerCase()
        if (
          lowerText.includes("i can't") ||
          lowerText.includes("i cannot") ||
          lowerText.includes("i'm not able") ||
          lowerText.includes("i am not able")
        ) {
          throw new LLMContentFilterError(this.name, textBlock.text)
        }

        return {
          content: this.parseJsonResponse(textBlock.text),
          usage: {
            promptTokens: data.usage?.input_tokens ?? 0,
            completionTokens: data.usage?.output_tokens ?? 0,
          },
          finishReason: data.stop_reason ?? 'unknown',
        }
      }
      throw new LLMApiError(this.name, 500, 'No tool use or text content in response')
    }

    return {
      content: toolUseBlock.input,
      usage: {
        promptTokens: data.usage?.input_tokens ?? 0,
        completionTokens: data.usage?.output_tokens ?? 0,
      },
      finishReason: data.stop_reason ?? 'unknown',
    }
  }

  /**
   * Build messages for the request
   */
  private buildMessages(prompt: string, images?: string[]): AnthropicMessage[] {
    // If no images, simple text message
    if (!images || images.length === 0) {
      return [
        {
          role: 'user',
          content: prompt,
        },
      ]
    }

    // Vision request: combine images and text
    const content: ContentBlock[] = []

    // Add images first
    for (const imageUrl of images) {
      // Parse data URL to extract base64 and media type
      const match = imageUrl.match(/^data:([^;]+);base64,(.+)$/)
      if (match) {
        content.push({
          type: 'image',
          source: {
            type: 'base64',
            media_type: match[1],
            data: match[2],
          },
        })
      }
    }

    // Add text prompt
    content.push({
      type: 'text',
      text: prompt,
    })

    return [
      {
        role: 'user',
        content,
      },
    ]
  }
}
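The forced tool call is the load-bearing design choice here: with `tool_choice: { type: 'tool', name: 'provide_response' }`, the Messages API returns the structured payload on a `tool_use` block instead of free text, which avoids most JSON-extraction failures. The text fallback and the refusal heuristics above only come into play when the model answers outside the tool.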
279
packages/llm-client/src/providers/base.test.ts
Normal file
@@ -0,0 +1,279 @@
import { describe, it, expect } from 'vitest'
import { z } from 'zod'
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
import { BaseProvider } from './base'

// Concrete implementation for testing
class TestProvider extends BaseProvider {
  async call(_request: ProviderRequest): Promise<ProviderResponse> {
    throw new Error('Not implemented')
  }

  // Expose protected methods for testing
  public testBuildPrompt(request: ProviderRequest): string {
    return this.buildPrompt(request)
  }

  public testBuildSchemaDocumentation(schema: Record<string, unknown>): string {
    return this.buildSchemaDocumentation(schema)
  }

  public testExtractFieldDescriptions(
    schema: Record<string, unknown>,
    path?: string
  ): string[] {
    return this.extractFieldDescriptions(schema, path)
  }
}

const testConfig: ProviderConfig = {
  name: 'test',
  apiKey: 'test-key',
  baseUrl: 'https://test.com',
  defaultModel: 'test-model',
}

describe('BaseProvider', () => {
  describe('buildSchemaDocumentation', () => {
    it('should include schema in documentation', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        properties: {
          name: { type: 'string' },
        },
      }

      const docs = provider.testBuildSchemaDocumentation(schema)

      expect(docs).toContain('## Response Format')
      expect(docs).toContain('JSON Schema')
      expect(docs).toContain('"type": "object"')
    })

    it('should extract field descriptions from schema', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        description: 'The analysis result',
        properties: {
          sentiment: {
            type: 'string',
            description: 'The detected sentiment of the text',
          },
          confidence: {
            type: 'number',
            description: 'Confidence score between 0 and 1',
          },
        },
      }

      const docs = provider.testBuildSchemaDocumentation(schema)

      expect(docs).toContain('### Field Descriptions')
      expect(docs).toContain('**Response**: The analysis result')
      expect(docs).toContain('**sentiment**: The detected sentiment of the text')
      expect(docs).toContain('**confidence**: Confidence score between 0 and 1')
    })
  })

  describe('extractFieldDescriptions', () => {
    it('should extract top-level description', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        description: 'Root description',
      }

      const descriptions = provider.testExtractFieldDescriptions(schema)

      expect(descriptions).toContain('- **Response**: Root description')
    })

    it('should extract nested property descriptions', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        properties: {
          user: {
            type: 'object',
            description: 'User information',
            properties: {
              name: {
                type: 'string',
                description: 'Full name of the user',
              },
              age: {
                type: 'number',
                description: 'Age in years',
              },
            },
          },
        },
      }

      const descriptions = provider.testExtractFieldDescriptions(schema)

      expect(descriptions).toContain('- **user**: User information')
      expect(descriptions).toContain('- **user.name**: Full name of the user')
      expect(descriptions).toContain('- **user.age**: Age in years')
    })

    it('should extract array item descriptions', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        properties: {
          items: {
            type: 'array',
            description: 'List of items',
            items: {
              type: 'object',
              description: 'A single item',
              properties: {
                id: {
                  type: 'number',
                  description: 'Unique identifier',
                },
              },
            },
          },
        },
      }

      const descriptions = provider.testExtractFieldDescriptions(schema)

      expect(descriptions).toContain('- **items**: List of items')
      expect(descriptions).toContain('- **items[]**: A single item')
      expect(descriptions).toContain('- **items[].id**: Unique identifier')
    })

    it('should handle anyOf/oneOf for nullable types', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        properties: {
          value: {
            anyOf: [
              { type: 'number', description: 'Numeric value' },
              { type: 'null' },
            ],
          },
        },
      }

      const descriptions = provider.testExtractFieldDescriptions(schema)

      expect(descriptions).toContain('- **value**: Numeric value')
    })

    it('should skip fields without descriptions', () => {
      const provider = new TestProvider(testConfig)
      const schema = {
        type: 'object',
        properties: {
          withDesc: {
            type: 'string',
            description: 'Has description',
          },
          withoutDesc: {
            type: 'number',
          },
        },
      }

      const descriptions = provider.testExtractFieldDescriptions(schema)

      expect(descriptions).toHaveLength(1)
      expect(descriptions[0]).toContain('withDesc')
    })
  })

  describe('buildPrompt with Zod schemas', () => {
    it('should include Zod describe() annotations in prompt', () => {
      const provider = new TestProvider(testConfig)

      const schema = z
        .object({
          sentiment: z
            .enum(['positive', 'negative', 'neutral'])
            .describe('The overall sentiment detected in the text'),
          confidence: z
            .number()
            .min(0)
            .max(1)
            .describe('How confident the model is in its assessment (0-1)'),
          keywords: z
            .array(z.string().describe('A relevant keyword'))
            .describe('Key terms that influenced the sentiment'),
        })
        .describe('Sentiment analysis result')

      // Use Zod v4's native toJSONSchema
      const jsonSchema = z.toJSONSchema(schema) as Record<string, unknown>

      const prompt = provider.testBuildPrompt({
        prompt: 'Analyze the sentiment of: "I love this product!"',
        jsonSchema,
        model: 'test-model',
      })

      // Check original prompt is included
      expect(prompt).toContain('Analyze the sentiment of')

      // Check descriptions are extracted
      expect(prompt).toContain('Sentiment analysis result')
      expect(prompt).toContain('The overall sentiment detected in the text')
      expect(prompt).toContain('How confident the model is in its assessment')
      expect(prompt).toContain('Key terms that influenced the sentiment')
    })

    it('should include validation feedback when retrying', () => {
      const provider = new TestProvider(testConfig)

      const schema = z.object({
        value: z.number().describe('A numeric value'),
      })

      // Use Zod v4's native toJSONSchema
      const jsonSchema = z.toJSONSchema(schema) as Record<string, unknown>

      const prompt = provider.testBuildPrompt({
        prompt: 'Extract the number',
        jsonSchema,
        model: 'test-model',
        validationFeedback: {
          field: 'value',
          error: 'Expected number, received string',
          received: 'five',
        },
      })

      // Check validation feedback is included
      expect(prompt).toContain('PREVIOUS ATTEMPT HAD VALIDATION ERROR')
      expect(prompt).toContain('Field: value')
      expect(prompt).toContain('Expected number, received string')
      expect(prompt).toContain('Received: "five"')
    })
  })

  describe('parseJsonResponse', () => {
    it('should parse plain JSON', () => {
      const provider = new TestProvider(testConfig)
      const result = provider['parseJsonResponse']('{"value": 42}')
      expect(result).toEqual({ value: 42 })
    })

    it('should parse JSON in markdown code blocks', () => {
      const provider = new TestProvider(testConfig)
      const result = provider['parseJsonResponse']('```json\n{"value": 42}\n```')
      expect(result).toEqual({ value: 42 })
    })

    it('should parse JSON in plain code blocks', () => {
      const provider = new TestProvider(testConfig)
      const result = provider['parseJsonResponse']('```\n{"value": 42}\n```')
      expect(result).toEqual({ value: 42 })
    })
  })
})
163
packages/llm-client/src/providers/base.ts
Normal file
@@ -0,0 +1,163 @@
import type {
  LLMProvider,
  ProviderConfig,
  ProviderRequest,
  ProviderResponse,
} from '../types'
import { LLMJsonParseError } from '../types'
import { buildFeedbackPrompt } from '../retry'

/**
 * Base class for LLM providers
 *
 * Provides common functionality for building prompts and handling
 * validation feedback.
 */
export abstract class BaseProvider implements LLMProvider {
  constructor(protected readonly config: ProviderConfig) {}

  get name(): string {
    return this.config.name
  }

  /**
   * Make an LLM call
   */
  abstract call(request: ProviderRequest): Promise<ProviderResponse>

  /**
   * Build the prompt with schema context and validation feedback
   *
   * The prompt is structured as:
   * 1. User's original prompt
   * 2. Schema documentation (extracted from Zod .describe() annotations)
   * 3. Validation feedback (if retrying after a failed attempt)
   */
  protected buildPrompt(request: ProviderRequest): string {
    let prompt = request.prompt

    // Add schema documentation
    prompt += this.buildSchemaDocumentation(request.jsonSchema)

    // Add validation feedback if retrying
    if (request.validationFeedback) {
      prompt += buildFeedbackPrompt(request.validationFeedback)
    }

    return prompt
  }

  /**
   * Build human-readable documentation from the JSON schema
   *
   * Extracts descriptions from Zod's .describe() annotations and formats
   * them as clear instructions for the LLM.
   */
  protected buildSchemaDocumentation(schema: Record<string, unknown>): string {
    const docs: string[] = []
    docs.push('\n\n## Response Format\n')
    docs.push('Respond with JSON matching the following structure:\n')

    // Extract and format field descriptions
    const fieldDocs = this.extractFieldDescriptions(schema)
    if (fieldDocs.length > 0) {
      docs.push('\n### Field Descriptions\n')
      docs.push(fieldDocs.join('\n'))
    }

    // Include the schema structure for reference
    docs.push('\n\n### JSON Schema\n```json\n')
    docs.push(JSON.stringify(schema, null, 2))
    docs.push('\n```')

    return docs.join('')
  }

  /**
   * Recursively extract field descriptions from a JSON schema
   */
  protected extractFieldDescriptions(
    schema: Record<string, unknown>,
    path: string = ''
  ): string[] {
    const descriptions: string[] = []

    // Handle schema description at current level
    if (typeof schema.description === 'string' && schema.description) {
      const fieldName = path || 'Response'
      descriptions.push(`- **${fieldName}**: ${schema.description}`)
    }

    // Handle object properties
    if (schema.properties && typeof schema.properties === 'object') {
      const properties = schema.properties as Record<string, Record<string, unknown>>
      for (const [key, value] of Object.entries(properties)) {
        const fieldPath = path ? `${path}.${key}` : key
        descriptions.push(...this.extractFieldDescriptions(value, fieldPath))
      }
    }

    // Handle array items
    if (schema.items && typeof schema.items === 'object') {
      const itemPath = path ? `${path}[]` : 'items'
      descriptions.push(
        ...this.extractFieldDescriptions(schema.items as Record<string, unknown>, itemPath)
      )
    }

    // Handle anyOf/oneOf/allOf (for nullable types, unions, etc.)
    for (const key of ['anyOf', 'oneOf', 'allOf'] as const) {
      const variants = schema[key]
      if (Array.isArray(variants)) {
        for (const variant of variants) {
          if (variant && typeof variant === 'object') {
            descriptions.push(
              ...this.extractFieldDescriptions(variant as Record<string, unknown>, path)
            )
          }
        }
      }
    }

    return descriptions
  }

  /**
   * Parse JSON response from LLM
   */
  protected parseJsonResponse(content: string): unknown {
    // Handle markdown code blocks
    const jsonMatch = content.match(/```(?:json)?\s*([\s\S]*?)```/)
    const jsonStr = jsonMatch ? jsonMatch[1].trim() : content.trim()

    try {
      return JSON.parse(jsonStr)
    } catch (e) {
      const errorMessage = e instanceof Error ? e.message : 'Unknown parse error'
      throw new LLMJsonParseError(content, errorMessage)
    }
  }

  /**
   * Parse Retry-After header value
   * @returns milliseconds to wait, or undefined if not present
   */
  protected parseRetryAfter(headers: Headers): number | undefined {
    const retryAfter = headers.get('retry-after')
    if (!retryAfter) return undefined

    // Could be seconds (number) or an HTTP date
    const seconds = parseInt(retryAfter, 10)
    if (!isNaN(seconds)) {
      return seconds * 1000
    }

    // Try parsing as HTTP date
    const date = new Date(retryAfter)
    if (!isNaN(date.getTime())) {
      return Math.max(0, date.getTime() - Date.now())
    }

    return undefined
  }
}
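For orientation, here is a minimal sketch (not part of this commit) of how a BaseProvider subclass composes buildPrompt and parseJsonResponse. The EchoProvider name and its canned response are invented for illustration; a real provider would send the built prompt to an API.

import { BaseProvider } from './base'
import type { ProviderRequest, ProviderResponse } from '../types'

class EchoProvider extends BaseProvider {
  async call(request: ProviderRequest): Promise<ProviderResponse> {
    // buildPrompt appends the "## Response Format" docs and any retry feedback
    const prompt = this.buildPrompt(request)
    console.log(prompt)

    // Canned JSON wrapped in a markdown fence, to show parseJsonResponse stripping it
    const raw = '```json\n{"value": 42}\n```'
    return {
      content: this.parseJsonResponse(raw),
      usage: { promptTokens: 0, completionTokens: 0 },
      finishReason: 'stop',
    }
  }
}

// Hypothetical config values for the sketch
const echo = new EchoProvider({
  name: 'echo',
  apiKey: 'unused',
  baseUrl: 'unused',
  defaultModel: 'echo-1',
})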
204
packages/llm-client/src/providers/openai.ts
Normal file
@@ -0,0 +1,204 @@
import type { ProviderConfig, ProviderRequest, ProviderResponse } from '../types'
import { LLMApiError, LLMTruncationError, LLMContentFilterError } from '../types'
import { BaseProvider } from './base'

/**
 * OpenAI message content item
 */
interface ContentItem {
  type: 'text' | 'image_url'
  text?: string
  image_url?: {
    url: string
    detail?: 'auto' | 'low' | 'high'
  }
}

/**
 * OpenAI chat message
 */
interface ChatMessage {
  role: 'system' | 'user' | 'assistant'
  content: string | ContentItem[]
}

/**
 * OpenAI chat completion response
 */
interface ChatCompletionResponse {
  id: string
  choices: Array<{
    index: number
    message: {
      role: string
      content: string | null
      refusal?: string | null
    }
    finish_reason: 'stop' | 'length' | 'content_filter' | 'tool_calls' | 'function_call' | null
  }>
  usage: {
    prompt_tokens: number
    completion_tokens: number
    total_tokens: number
  }
}

/**
 * OpenAI error response structure
 */
interface OpenAIErrorResponse {
  error?: {
    message?: string
    type?: string
    code?: string
  }
}

/**
 * OpenAI provider implementation
 *
 * Supports both text and vision models using the chat completions API
 * with JSON mode for structured output.
 */
export class OpenAIProvider extends BaseProvider {
  constructor(config: ProviderConfig) {
    super(config)
  }

  async call(request: ProviderRequest): Promise<ProviderResponse> {
    const prompt = this.buildPrompt(request)
    const messages = this.buildMessages(prompt, request.images)

    const requestBody: Record<string, unknown> = {
      model: request.model,
      messages,
      response_format: {
        type: 'json_schema',
        json_schema: {
          name: 'response',
          schema: request.jsonSchema,
          strict: true,
        },
      },
    }

    const response = await fetch(`${this.config.baseUrl}/chat/completions`, {
      method: 'POST',
      headers: {
        Authorization: `Bearer ${this.config.apiKey}`,
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(requestBody),
    })

    if (!response.ok) {
      const errorText = await response.text()
      let errorMessage = errorText
      try {
        const errorJson = JSON.parse(errorText) as OpenAIErrorResponse
        errorMessage = errorJson.error?.message ?? errorText
      } catch {
        // Keep original text
      }

      // Parse Retry-After header for rate limits
      const retryAfterMs = this.parseRetryAfter(response.headers)

      throw new LLMApiError(this.name, response.status, errorMessage, retryAfterMs)
    }

    const data = (await response.json()) as ChatCompletionResponse

    if (!data.choices || data.choices.length === 0) {
      throw new LLMApiError(this.name, 500, 'No response choices returned')
    }

    const choice = data.choices[0]

    // Check for content filter refusal
    if (choice.finish_reason === 'content_filter') {
      throw new LLMContentFilterError(
        this.name,
        choice.message.refusal ?? 'Content was filtered by the model'
      )
    }

    // Check for model refusal (new in GPT-4o)
    if (choice.message.refusal) {
      throw new LLMContentFilterError(this.name, choice.message.refusal)
    }

    // Check for truncation due to token limits
    if (choice.finish_reason === 'length') {
      // Try to parse whatever we got
      let partialContent: unknown = null
      if (choice.message.content) {
        try {
          partialContent = this.parseJsonResponse(choice.message.content)
        } catch {
          partialContent = choice.message.content
        }
      }
      throw new LLMTruncationError(this.name, partialContent)
    }

    // Check for null content
    if (!choice.message.content) {
      throw new LLMApiError(this.name, 500, 'Empty response content')
    }

    // Parse JSON response
    const parsedContent = this.parseJsonResponse(choice.message.content)

    return {
      content: parsedContent,
      usage: {
        promptTokens: data.usage?.prompt_tokens ?? 0,
        completionTokens: data.usage?.completion_tokens ?? 0,
      },
      finishReason: choice.finish_reason ?? 'unknown',
    }
  }

  /**
   * Build chat messages for the request
   */
  private buildMessages(prompt: string, images?: string[]): ChatMessage[] {
    // If there are no images, send a simple text message
    if (!images || images.length === 0) {
      return [
        {
          role: 'user',
          content: prompt,
        },
      ]
    }

    // Vision request: combine images and text
    const content: ContentItem[] = []

    // Add images first
    for (const imageUrl of images) {
      content.push({
        type: 'image_url',
        image_url: {
          url: imageUrl,
          detail: 'high',
        },
      })
    }

    // Add text prompt
    content.push({
      type: 'text',
      text: prompt,
    })

    return [
      {
        role: 'user',
        content,
      },
    ]
  }
}
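A hypothetical usage sketch (not part of this commit): calling the provider directly with a vision request. The config values below are placeholders; in the real package they are loaded from environment variables, and callers would normally go through the client rather than the provider.

import { OpenAIProvider } from './openai'

const provider = new OpenAIProvider({
  name: 'openai',
  apiKey: process.env.OPENAI_API_KEY ?? '',
  baseUrl: 'https://api.openai.com/v1',
  defaultModel: 'gpt-4o',
})

const response = await provider.call({
  prompt: 'List the problems on this worksheet.',
  images: ['data:image/png;base64,...'], // base64 data URL, elided here
  jsonSchema: {
    type: 'object',
    properties: { problems: { type: 'array', items: { type: 'string' } } },
    required: ['problems'],
  },
  model: 'gpt-4o',
})
// response.content is already parsed JSON; finishReason distinguishes
// 'stop' from 'length' (truncation) and 'content_filter' (refusal).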
606
packages/llm-client/src/retry.test.ts
Normal file
@@ -0,0 +1,606 @@
import { describe, it, expect, vi } from 'vitest'
import {
  executeWithRetry,
  buildFeedbackPrompt,
  isRetryableError,
  getRetryDelay,
} from './retry'
import {
  LLMValidationError,
  LLMApiError,
  LLMTruncationError,
  LLMContentFilterError,
  LLMJsonParseError,
} from './types'
import type { ValidationFeedback, LLMProgress } from './types'

describe('retry', () => {
  describe('isRetryableError', () => {
    it('should return false for LLMContentFilterError', () => {
      const error = new LLMContentFilterError('openai', 'Content was filtered')
      expect(isRetryableError(error)).toBe(false)
    })

    it('should return false for LLMTruncationError', () => {
      const error = new LLMTruncationError('openai', { partial: 'data' })
      expect(isRetryableError(error)).toBe(false)
    })

    it('should return true for LLMValidationError', () => {
      const error = new LLMValidationError({ field: 'test', error: 'invalid' })
      expect(isRetryableError(error)).toBe(true)
    })

    it('should return true for LLMJsonParseError', () => {
      const error = new LLMJsonParseError('not json', 'Unexpected token')
      expect(isRetryableError(error)).toBe(true)
    })

    it('should return true for rate limit errors (429)', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited')
      expect(isRetryableError(error)).toBe(true)
    })

    it('should return true for server errors (5xx)', () => {
      const error500 = new LLMApiError('openai', 500, 'Internal error')
      const error502 = new LLMApiError('openai', 502, 'Bad gateway')
      const error503 = new LLMApiError('openai', 503, 'Service unavailable')
      expect(isRetryableError(error500)).toBe(true)
      expect(isRetryableError(error502)).toBe(true)
      expect(isRetryableError(error503)).toBe(true)
    })

    it('should return false for client errors (4xx except 429)', () => {
      const error400 = new LLMApiError('openai', 400, 'Bad request')
      const error401 = new LLMApiError('openai', 401, 'Unauthorized')
      const error403 = new LLMApiError('openai', 403, 'Forbidden')
      const error404 = new LLMApiError('openai', 404, 'Not found')
      expect(isRetryableError(error400)).toBe(false)
      expect(isRetryableError(error401)).toBe(false)
      expect(isRetryableError(error403)).toBe(false)
      expect(isRetryableError(error404)).toBe(false)
    })

    it('should return true for generic errors (network issues)', () => {
      const error = new Error('Network error')
      expect(isRetryableError(error)).toBe(true)
    })
  })

  describe('getRetryDelay', () => {
    it('should use Retry-After from LLMApiError when available', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited', 5000)
      const delay = getRetryDelay(error, 1, 1000, 60000)
      expect(delay).toBe(5000)
    })

    it('should cap Retry-After at maxDelayMs', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited', 120000)
      const delay = getRetryDelay(error, 1, 1000, 60000)
      expect(delay).toBe(60000)
    })

    it('should use exponential backoff without Retry-After', () => {
      const delay1 = getRetryDelay(null, 1, 1000, 60000)
      const delay2 = getRetryDelay(null, 2, 1000, 60000)
      const delay3 = getRetryDelay(null, 3, 1000, 60000)

      // Should roughly double each time (with jitter)
      expect(delay1).toBeGreaterThanOrEqual(1000)
      expect(delay1).toBeLessThan(1200) // 10% jitter max
      expect(delay2).toBeGreaterThanOrEqual(2000)
      expect(delay2).toBeLessThan(2400)
      expect(delay3).toBeGreaterThanOrEqual(4000)
      expect(delay3).toBeLessThan(4800)
    })

    it('should cap exponential backoff at maxDelayMs', () => {
      const delay = getRetryDelay(null, 10, 1000, 5000)
      expect(delay).toBe(5000)
    })
  })

  describe('executeWithRetry', () => {
    it('should succeed on first attempt when validation passes', async () => {
      const fn = vi.fn().mockResolvedValue({ value: 42 })
      const validate = vi.fn().mockReturnValue(null)

      const result = await executeWithRetry(fn, validate, { maxRetries: 2 })

      expect(result.result).toEqual({ value: 42 })
      expect(result.attempts).toBe(1)
      expect(fn).toHaveBeenCalledTimes(1)
      expect(fn).toHaveBeenCalledWith(undefined)
      expect(validate).toHaveBeenCalledTimes(1)
    })

    it('should retry and succeed when validation fails then passes', async () => {
      const validationError: ValidationFeedback = {
        field: 'value',
        error: 'must be positive',
        received: -1,
      }

      let callCount = 0
      const fn = vi.fn().mockImplementation(() => {
        callCount++
        return Promise.resolve({ value: callCount === 1 ? -1 : 42 })
      })

      const validate = vi.fn().mockImplementation((result) => {
        if (result.value < 0) {
          return validationError
        }
        return null
      })

      const result = await executeWithRetry(fn, validate, {
        maxRetries: 2,
        baseDelayMs: 10,
      })

      expect(result.result).toEqual({ value: 42 })
      expect(result.attempts).toBe(2)
      expect(fn).toHaveBeenCalledTimes(2)
      expect(fn).toHaveBeenLastCalledWith(validationError)
    })

    it('should throw LLMValidationError after max retries', async () => {
      const validationError: ValidationFeedback = {
        field: 'value',
        error: 'always invalid',
      }

      const fn = vi.fn().mockResolvedValue({ value: 'bad' })
      const validate = vi.fn().mockReturnValue(validationError)

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 2, baseDelayMs: 10 })
      ).rejects.toThrow(LLMValidationError)

      expect(fn).toHaveBeenCalledTimes(3)
    })

    it('should include validation error in thrown LLMValidationError', async () => {
      const validationError: ValidationFeedback = {
        field: 'items.0.price',
        error: 'Expected number, received string',
        received: 'free',
        expected: 'number',
      }

      const fn = vi.fn().mockResolvedValue({})
      const validate = vi.fn().mockReturnValue(validationError)

      try {
        await executeWithRetry(fn, validate, { maxRetries: 1, baseDelayMs: 10 })
        expect.fail('Should have thrown')
      } catch (error) {
        expect(error).toBeInstanceOf(LLMValidationError)
        const llmError = error as LLMValidationError
        expect(llmError.feedback).toEqual(validationError)
        expect(llmError.message).toContain('items.0.price')
      }
    })

    it('should retry on server errors', async () => {
      let callCount = 0
      const fn = vi.fn().mockImplementation(() => {
        callCount++
        if (callCount === 1) {
          return Promise.reject(new LLMApiError('openai', 500, 'Server error'))
        }
        return Promise.resolve({ value: 'success' })
      })

      const validate = vi.fn().mockReturnValue(null)

      const result = await executeWithRetry(fn, validate, {
        maxRetries: 2,
        baseDelayMs: 10,
      })

      expect(result.result).toEqual({ value: 'success' })
      expect(result.attempts).toBe(2)
    })

    it('should retry on rate limit errors', async () => {
      let callCount = 0
      const fn = vi.fn().mockImplementation(() => {
        callCount++
        if (callCount === 1) {
          return Promise.reject(new LLMApiError('openai', 429, 'Rate limited', 100))
        }
        return Promise.resolve({ value: 'success' })
      })

      const validate = vi.fn().mockReturnValue(null)

      const result = await executeWithRetry(fn, validate, {
        maxRetries: 2,
        baseDelayMs: 10,
      })

      expect(result.result).toEqual({ value: 'success' })
      expect(result.attempts).toBe(2)
    })

    it('should NOT retry on content filter errors', async () => {
      const fn = vi.fn().mockRejectedValue(
        new LLMContentFilterError('openai', 'Content was filtered')
      )
      const validate = vi.fn()

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 2, baseDelayMs: 10 })
      ).rejects.toThrow(LLMContentFilterError)

      expect(fn).toHaveBeenCalledTimes(1)
    })

    it('should NOT retry on truncation errors', async () => {
      const fn = vi.fn().mockRejectedValue(
        new LLMTruncationError('openai', { partial: 'data' })
      )
      const validate = vi.fn()

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 2, baseDelayMs: 10 })
      ).rejects.toThrow(LLMTruncationError)

      expect(fn).toHaveBeenCalledTimes(1)
    })

    it('should NOT retry on client errors (4xx)', async () => {
      const fn = vi.fn().mockRejectedValue(
        new LLMApiError('openai', 400, 'Bad request')
      )
      const validate = vi.fn()

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 2, baseDelayMs: 10 })
      ).rejects.toThrow(LLMApiError)

      expect(fn).toHaveBeenCalledTimes(1)
    })

    it('should retry on JSON parse errors and include feedback', async () => {
      let callCount = 0
      const fn = vi.fn().mockImplementation((feedback) => {
        callCount++
        if (callCount === 1) {
          return Promise.reject(new LLMJsonParseError('not json {', 'Unexpected end'))
        }
        // On retry, feedback should be set
        expect(feedback).toBeDefined()
        expect(feedback?.field).toBe('root')
        expect(feedback?.error).toContain('not valid JSON')
        return Promise.resolve({ value: 'success' })
      })

      const validate = vi.fn().mockReturnValue(null)

      const result = await executeWithRetry(fn, validate, {
        maxRetries: 2,
        baseDelayMs: 10,
      })

      expect(result.result).toEqual({ value: 'success' })
      expect(result.attempts).toBe(2)
    })

    it('should throw API error after max retries', async () => {
      const apiError = new LLMApiError('openai', 500, 'API is down')
      const fn = vi.fn().mockRejectedValue(apiError)
      const validate = vi.fn()

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 2, baseDelayMs: 10 })
      ).rejects.toThrow('API is down')

      expect(fn).toHaveBeenCalledTimes(3)
      expect(validate).not.toHaveBeenCalled()
    })

    it('should call onProgress at each stage', async () => {
      const progressCalls: LLMProgress[] = []
      const onProgress = vi.fn((progress: LLMProgress) => {
        progressCalls.push({ ...progress })
      })

      const fn = vi.fn().mockResolvedValue({ value: 42 })
      const validate = vi.fn().mockReturnValue(null)

      await executeWithRetry(fn, validate, {
        maxRetries: 2,
        onProgress,
      })

      expect(progressCalls).toHaveLength(2)
      expect(progressCalls[0].stage).toBe('calling')
      expect(progressCalls[0].attempt).toBe(1)
      expect(progressCalls[0].maxAttempts).toBe(3)
      expect(progressCalls[0].message).toBe('Calling LLM...')
      expect(progressCalls[1].stage).toBe('validating')
      expect(progressCalls[1].attempt).toBe(1)
    })

    it('should call onProgress with retry stage on validation failure', async () => {
      const progressCalls: LLMProgress[] = []
      const onProgress = vi.fn((progress: LLMProgress) => {
        progressCalls.push({ ...progress })
      })

      const validationError: ValidationFeedback = {
        field: 'name',
        error: 'too short',
      }

      let callCount = 0
      const fn = vi.fn().mockImplementation(() => {
        callCount++
        return Promise.resolve({ name: callCount === 1 ? 'a' : 'valid' })
      })

      const validate = vi.fn().mockImplementation((result) => {
        if (result.name.length < 2) return validationError
        return null
      })

      await executeWithRetry(fn, validate, {
        maxRetries: 2,
        baseDelayMs: 10,
        onProgress,
      })

      expect(progressCalls).toHaveLength(4)
      expect(progressCalls[0].stage).toBe('calling')
      expect(progressCalls[1].stage).toBe('validating')
      expect(progressCalls[2].stage).toBe('retrying')
      expect(progressCalls[2].message).toContain('Retry 1/2')
      expect(progressCalls[2].message).toContain('name')
      expect(progressCalls[2].validationError).toEqual(validationError)
      expect(progressCalls[3].stage).toBe('validating')
    })

    it('should not throw LLMValidationError as API error', async () => {
      const validationError: ValidationFeedback = {
        field: 'test',
        error: 'invalid',
      }

      const fn = vi.fn().mockResolvedValue({})
      const validate = vi.fn().mockReturnValue(validationError)

      await expect(
        executeWithRetry(fn, validate, { maxRetries: 0, baseDelayMs: 10 })
      ).rejects.toThrow(LLMValidationError)
    })

    it('should respect maxDelayMs option', async () => {
      const startTime = Date.now()
      let callCount = 0

      const fn = vi.fn().mockImplementation(() => {
        callCount++
        if (callCount < 3) {
          return Promise.reject(new LLMApiError('openai', 500, 'Error'))
        }
        return Promise.resolve({ value: 'success' })
      })

      const validate = vi.fn().mockReturnValue(null)

      await executeWithRetry(fn, validate, {
        maxRetries: 3,
        baseDelayMs: 10,
        maxDelayMs: 50,
      })

      const elapsed = Date.now() - startTime
      // With maxDelayMs of 50ms and 2 retries, total delay should be under 200ms
      expect(elapsed).toBeLessThan(200)
    })
  })

  describe('buildFeedbackPrompt', () => {
    it('should build basic feedback prompt', () => {
      const feedback: ValidationFeedback = {
        field: 'sentiment',
        error: 'Invalid enum value',
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('PREVIOUS ATTEMPT HAD VALIDATION ERROR')
      expect(prompt).toContain('Field: sentiment')
      expect(prompt).toContain('Error: Invalid enum value')
      expect(prompt).toContain('Please correct this error')
    })

    it('should include received value when provided', () => {
      const feedback: ValidationFeedback = {
        field: 'count',
        error: 'Expected number',
        received: 'five',
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('Received: "five"')
    })

    it('should include expected value when provided', () => {
      const feedback: ValidationFeedback = {
        field: 'type',
        error: 'Wrong type',
        expected: 'number',
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('Expected: "number"')
    })

    it('should include valid options when provided', () => {
      const feedback: ValidationFeedback = {
        field: 'status',
        error: 'Invalid enum value',
        validOptions: ['pending', 'active', 'completed'],
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('Valid options: pending, active, completed')
    })

    it('should include all fields when fully specified', () => {
      const feedback: ValidationFeedback = {
        field: 'items.0.status',
        error: 'Invalid status value',
        received: 'unknown',
        expected: 'one of the valid options',
        validOptions: ['draft', 'published', 'archived'],
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('Field: items.0.status')
      expect(prompt).toContain('Error: Invalid status value')
      expect(prompt).toContain('Received: "unknown"')
      expect(prompt).toContain('Expected: "one of the valid options"')
      expect(prompt).toContain('Valid options: draft, published, archived')
    })

    it('should handle empty validOptions array', () => {
      const feedback: ValidationFeedback = {
        field: 'test',
        error: 'error',
        validOptions: [],
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).not.toContain('Valid options:')
    })

    it('should handle complex received values', () => {
      const feedback: ValidationFeedback = {
        field: 'data',
        error: 'Invalid structure',
        received: { nested: { value: [1, 2, 3] } },
      }

      const prompt = buildFeedbackPrompt(feedback)

      expect(prompt).toContain('Received: {"nested":{"value":[1,2,3]}}')
    })
  })
})

describe('LLMApiError', () => {
  describe('isRateLimited', () => {
    it('should return true for 429', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited')
      expect(error.isRateLimited()).toBe(true)
    })

    it('should return false for other status codes', () => {
      expect(new LLMApiError('openai', 400, 'Bad').isRateLimited()).toBe(false)
      expect(new LLMApiError('openai', 500, 'Error').isRateLimited()).toBe(false)
    })
  })

  describe('isServerError', () => {
    it('should return true for 5xx errors', () => {
      expect(new LLMApiError('openai', 500, 'Error').isServerError()).toBe(true)
      expect(new LLMApiError('openai', 502, 'Error').isServerError()).toBe(true)
      expect(new LLMApiError('openai', 503, 'Error').isServerError()).toBe(true)
      expect(new LLMApiError('openai', 599, 'Error').isServerError()).toBe(true)
    })

    it('should return false for non-5xx errors', () => {
      expect(new LLMApiError('openai', 400, 'Error').isServerError()).toBe(false)
      expect(new LLMApiError('openai', 429, 'Error').isServerError()).toBe(false)
      expect(new LLMApiError('openai', 600, 'Error').isServerError()).toBe(false)
    })
  })

  describe('isClientError', () => {
    it('should return true for 4xx errors except 429', () => {
      expect(new LLMApiError('openai', 400, 'Error').isClientError()).toBe(true)
      expect(new LLMApiError('openai', 401, 'Error').isClientError()).toBe(true)
      expect(new LLMApiError('openai', 403, 'Error').isClientError()).toBe(true)
      expect(new LLMApiError('openai', 404, 'Error').isClientError()).toBe(true)
    })

    it('should return false for 429 (rate limit)', () => {
      expect(new LLMApiError('openai', 429, 'Error').isClientError()).toBe(false)
    })

    it('should return false for non-4xx errors', () => {
      expect(new LLMApiError('openai', 500, 'Error').isClientError()).toBe(false)
      expect(new LLMApiError('openai', 200, 'Error').isClientError()).toBe(false)
    })
  })

  describe('retryAfterMs', () => {
    it('should store retry-after value', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited', 5000)
      expect(error.retryAfterMs).toBe(5000)
    })

    it('should be undefined when not provided', () => {
      const error = new LLMApiError('openai', 429, 'Rate limited')
      expect(error.retryAfterMs).toBeUndefined()
    })
  })
})

describe('Error types', () => {
  describe('LLMTruncationError', () => {
    it('should store partial content', () => {
      const error = new LLMTruncationError('openai', { partial: 'data' })
      expect(error.partialContent).toEqual({ partial: 'data' })
      expect(error.provider).toBe('openai')
      expect(error.message).toContain('truncated')
    })
  })

  describe('LLMContentFilterError', () => {
    it('should store filter reason', () => {
      const error = new LLMContentFilterError('openai', 'Harmful content detected')
      expect(error.filterReason).toBe('Harmful content detected')
      expect(error.provider).toBe('openai')
      expect(error.message).toContain('content filter')
    })

    it('should handle missing filter reason', () => {
      const error = new LLMContentFilterError('openai')
      expect(error.filterReason).toBeUndefined()
      expect(error.message).toContain('content filter')
    })
  })

  describe('LLMJsonParseError', () => {
    it('should store raw content and parse error', () => {
      const error = new LLMJsonParseError('not json {', 'Unexpected token')
      expect(error.rawContent).toBe('not json {')
      expect(error.parseError).toBe('Unexpected token')
      expect(error.message).toContain('Failed to parse')
    })
  })

  describe('LLMValidationError', () => {
    it('should store validation feedback', () => {
      const feedback: ValidationFeedback = {
        field: 'test.field',
        error: 'Invalid value',
      }
      const error = new LLMValidationError(feedback)
      expect(error.feedback).toEqual(feedback)
      expect(error.message).toContain('test.field')
      expect(error.message).toContain('Invalid value')
    })
  })
})
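The numeric bounds asserted in the backoff tests follow directly from the delay formula in getRetryDelay (shown in the next file); a quick check, as a sketch:

// delay(attempt) = base * 2^(attempt-1) * (1 + j), with jitter j uniform in [0, 0.1).
// For base = 1000ms the actual ranges are:
//   attempt 1: [1000, 1100)   attempt 2: [2000, 2200)   attempt 3: [4000, 4400)
// The tests assert looser upper bounds (1200 / 2400 / 4800) to avoid flakiness.
const jitterBounds = (base: number, attempt: number): readonly [number, number] => {
  const d = base * 2 ** (attempt - 1)
  return [d, d * 1.1] // [min, max) before the maxDelayMs cap is applied
}
console.log(jitterBounds(1000, 3)) // [4000, 4400]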
225
packages/llm-client/src/retry.ts
Normal file
@@ -0,0 +1,225 @@
import type { LLMProgress, ValidationFeedback } from './types'
import {
  LLMValidationError,
  LLMApiError,
  LLMTruncationError,
  LLMContentFilterError,
  LLMJsonParseError,
} from './types'

/**
 * Options for retry execution
 */
export interface RetryOptions {
  /** Maximum number of retry attempts */
  maxRetries: number
  /** Progress callback */
  onProgress?: (progress: LLMProgress) => void
  /** Base delay for exponential backoff (ms) */
  baseDelayMs?: number
  /** Maximum delay cap (ms) */
  maxDelayMs?: number
}

/**
 * Check if an error is retryable
 */
export function isRetryableError(error: unknown): boolean {
  // Content filter errors are never retryable - the model refused
  if (error instanceof LLMContentFilterError) {
    return false
  }

  // Truncation errors might be retryable with a shorter prompt, but not automatically
  if (error instanceof LLMTruncationError) {
    return false
  }

  // Validation errors are retryable - we feed the error back to the LLM
  if (error instanceof LLMValidationError) {
    return true
  }

  // JSON parse errors are retryable - the LLM might return valid JSON next time
  if (error instanceof LLMJsonParseError) {
    return true
  }

  // API errors: rate limits and server errors are retryable, client errors are not
  if (error instanceof LLMApiError) {
    return error.isRateLimited() || error.isServerError()
  }

  // Generic errors (network issues, etc.) are retryable
  return true
}

/**
 * Get the delay before retrying an error
 */
export function getRetryDelay(
  error: unknown,
  attempt: number,
  baseDelayMs: number,
  maxDelayMs: number
): number {
  // If it's a rate limit with Retry-After, use that
  if (error instanceof LLMApiError && error.retryAfterMs) {
    return Math.min(error.retryAfterMs, maxDelayMs)
  }

  // Exponential backoff with jitter
  const exponentialDelay = baseDelayMs * Math.pow(2, attempt - 1)
  const jitter = Math.random() * 0.1 * exponentialDelay // 10% jitter
  return Math.min(exponentialDelay + jitter, maxDelayMs)
}

/**
 * Execute a function with retry logic and validation feedback
 *
 * On validation failure, the function is called again with the validation
 * error included, allowing the LLM to correct its response.
 *
 * Error handling:
 * - LLMValidationError: Retried with feedback in prompt
 * - LLMJsonParseError: Retried (LLM may return valid JSON)
 * - LLMApiError (429): Retried with Retry-After delay
 * - LLMApiError (5xx): Retried with exponential backoff
 * - LLMApiError (4xx): NOT retried (client error)
 * - LLMContentFilterError: NOT retried (model refused)
 * - LLMTruncationError: NOT retried (needs a shorter prompt)
 *
 * @param fn - Function to execute, receives validation feedback on retry
 * @param validate - Validation function, returns error or null if valid
 * @param options - Retry options
 */
export async function executeWithRetry<T>(
  fn: (feedback?: ValidationFeedback) => Promise<T>,
  validate: (result: T) => ValidationFeedback | null,
  options: RetryOptions
): Promise<{ result: T; attempts: number }> {
  const {
    maxRetries,
    onProgress,
    baseDelayMs = 1000,
    maxDelayMs = 60000,
  } = options
  const maxAttempts = maxRetries + 1

  let lastFeedback: ValidationFeedback | undefined
  let lastError: Error | undefined

  for (let attempt = 1; attempt <= maxAttempts; attempt++) {
    // Report progress
    onProgress?.({
      stage: attempt === 1 ? 'calling' : 'retrying',
      attempt,
      maxAttempts,
      message:
        attempt === 1
          ? 'Calling LLM...'
          : `Retry ${attempt - 1}/${maxRetries}: fixing ${lastFeedback?.field ?? 'error'}`,
      validationError: lastFeedback,
    })

    try {
      // Call the function (with feedback on retry)
      const result = await fn(lastFeedback)

      // Report validation stage
      onProgress?.({
        stage: 'validating',
        attempt,
        maxAttempts,
        message: 'Validating response...',
      })

      // Validate the result
      const error = validate(result)

      if (!error) {
        // Success!
        return { result, attempts: attempt }
      }

      // Validation failed
      if (attempt < maxAttempts) {
        // Store feedback for the next attempt
        lastFeedback = error

        // Back off before retrying
        const delayMs = getRetryDelay(null, attempt, baseDelayMs, maxDelayMs)
        await sleep(delayMs)
      } else {
        // Out of retries
        throw new LLMValidationError(error)
      }
    } catch (error) {
      // Re-throw validation errors from the final attempt
      if (error instanceof LLMValidationError) {
        throw error
      }

      // Check if this error is retryable
      if (!isRetryableError(error)) {
        throw error
      }

      // Store for potential re-throw
      lastError = error as Error

      if (attempt < maxAttempts) {
        // Calculate delay based on error type
        const delayMs = getRetryDelay(error, attempt, baseDelayMs, maxDelayMs)

        // For JSON parse errors, convert to validation feedback for the next attempt
        if (error instanceof LLMJsonParseError) {
          lastFeedback = {
            field: 'root',
            error: 'Response was not valid JSON',
            received: error.rawContent.substring(0, 500),
          }
        }

        await sleep(delayMs)
      } else {
        throw lastError
      }
    }
  }

  // Should never reach here
  throw lastError ?? new Error('Retry logic failed unexpectedly')
}

/**
 * Sleep for a specified duration
 */
function sleep(ms: number): Promise<void> {
  return new Promise((resolve) => setTimeout(resolve, ms))
}

/**
 * Build validation feedback message for inclusion in the retry prompt
 */
export function buildFeedbackPrompt(feedback: ValidationFeedback): string {
  let prompt = '\n\nPREVIOUS ATTEMPT HAD VALIDATION ERROR:\n'
  prompt += `Field: ${feedback.field}\n`
  prompt += `Error: ${feedback.error}\n`

  if (feedback.received !== undefined) {
    prompt += `Received: ${JSON.stringify(feedback.received)}\n`
  }

  if (feedback.expected !== undefined) {
    prompt += `Expected: ${JSON.stringify(feedback.expected)}\n`
  }

  if (feedback.validOptions && feedback.validOptions.length > 0) {
    prompt += `Valid options: ${feedback.validOptions.join(', ')}\n`
  }

  prompt += '\nPlease correct this error and provide a valid response.'

  return prompt
}
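A standalone sketch of how executeWithRetry pairs with a Zod schema. The client wires this together internally; the adapter and the callSomeLlm stub below are illustrative, not the package's actual implementation.

import { z } from 'zod'
import { executeWithRetry } from './retry'
import type { ValidationFeedback } from './types'

const Answer = z.object({ value: z.number().describe('A numeric value') })

async function callSomeLlm(feedback?: ValidationFeedback): Promise<unknown> {
  // Stand-in for a provider call; a real implementation would forward
  // `feedback` into the prompt via buildFeedbackPrompt.
  return { value: 42 }
}

const { result, attempts } = await executeWithRetry(
  // fn receives the previous attempt's ValidationFeedback on retries
  async (feedback?: ValidationFeedback) => callSomeLlm(feedback),
  // validate maps the first Zod issue into ValidationFeedback, or null if valid
  (candidate) => {
    const parsed = Answer.safeParse(candidate)
    if (parsed.success) return null
    const issue = parsed.error.issues[0]
    return { field: issue.path.join('.') || 'root', error: issue.message }
  },
  { maxRetries: 2, baseDelayMs: 1000 }
)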
233
packages/llm-client/src/types.ts
Normal file
@@ -0,0 +1,233 @@
import type { z } from 'zod'

/**
 * Provider configuration loaded from environment variables
 */
export interface ProviderConfig {
  /** Provider name (e.g., 'openai', 'anthropic') */
  name: string
  /** API key for authentication */
  apiKey: string
  /** Base URL for API requests */
  baseUrl: string
  /** Default model for this provider */
  defaultModel: string
  /** Provider-specific options */
  options?: Record<string, unknown>
}

/**
 * LLM client configuration
 */
export interface LLMClientConfig {
  /** Default provider to use */
  defaultProvider: string
  /** Default model (overrides provider default) */
  defaultModel?: string
  /** Configured providers */
  providers: Record<string, ProviderConfig>
  /** Default maximum retry attempts */
  defaultMaxRetries: number
}

/**
 * Request to make an LLM call with type-safe schema validation
 */
export interface LLMRequest<T extends z.ZodType> {
  /** The prompt to send to the LLM */
  prompt: string
  /** Base64 data URLs for vision requests */
  images?: string[]
  /** Zod schema for response validation */
  schema: T
  /** Override default provider */
  provider?: string
  /** Override default model */
  model?: string
  /** Maximum retry attempts (default: 2) */
  maxRetries?: number
  /** Progress callback for UI feedback */
  onProgress?: (progress: LLMProgress) => void
}

/**
 * Progress updates during an LLM call
 */
export interface LLMProgress {
  /** Current stage of the call */
  stage: 'preparing' | 'calling' | 'validating' | 'retrying'
  /** Current attempt number (1-indexed) */
  attempt: number
  /** Maximum number of attempts */
  maxAttempts: number
  /** Human-readable status message */
  message: string
  /** Validation error from previous attempt (for retries) */
  validationError?: ValidationFeedback
}

/**
 * Validation error feedback for retry prompts
 */
export interface ValidationFeedback {
  /** Field path that failed validation */
  field: string
  /** Error description */
  error: string
  /** Value that was received */
  received?: unknown
  /** Expected value or type */
  expected?: unknown
  /** Valid options (for enum fields) */
  validOptions?: string[]
}

/**
 * Response from an LLM call
 */
export interface LLMResponse<T> {
  /** Validated response data (typed according to schema) */
  data: T
  /** Token usage statistics */
  usage: {
    promptTokens: number
    completionTokens: number
    totalTokens: number
  }
  /** Number of attempts needed */
  attempts: number
  /** Provider that was used */
  provider: string
  /** Model that was used */
  model: string
}

/**
 * Internal request passed to providers
 */
export interface ProviderRequest {
  /** The prompt to send */
  prompt: string
  /** Base64 data URLs for vision */
  images?: string[]
  /** JSON schema for structured output */
  jsonSchema: Record<string, unknown>
  /** Model to use */
  model: string
  /** Validation feedback from previous attempt */
  validationFeedback?: ValidationFeedback
}

/**
 * Internal response from providers
 */
export interface ProviderResponse {
  /** Raw content from the LLM (JSON string or parsed object) */
  content: unknown
  /** Token usage */
  usage: {
    promptTokens: number
    completionTokens: number
  }
  /** Finish reason */
  finishReason: string
}

/**
 * LLM provider interface for implementing different providers
 */
export interface LLMProvider {
  /** Provider name */
  readonly name: string
  /** Make an LLM call */
  call(request: ProviderRequest): Promise<ProviderResponse>
}

/**
 * Error thrown when LLM validation fails after all retries
 */
export class LLMValidationError extends Error {
  constructor(public readonly feedback: ValidationFeedback) {
    super(`LLM validation failed: ${feedback.field} - ${feedback.error}`)
    this.name = 'LLMValidationError'
  }
}

/**
 * Error thrown when a provider is not configured
 */
export class ProviderNotConfiguredError extends Error {
  constructor(provider: string) {
    super(`Provider '${provider}' is not configured. Check your environment variables.`)
    this.name = 'ProviderNotConfiguredError'
  }
}

/**
 * Error thrown when an LLM API call fails
 */
export class LLMApiError extends Error {
  constructor(
    public readonly provider: string,
    public readonly statusCode: number,
    message: string,
    public readonly retryAfterMs?: number
  ) {
    super(`${provider} API error (${statusCode}): ${message}`)
    this.name = 'LLMApiError'
  }

  /** Check if this is a rate limit error */
  isRateLimited(): boolean {
    return this.statusCode === 429
  }

  /** Check if this is a server error that may be transient */
  isServerError(): boolean {
    return this.statusCode >= 500 && this.statusCode < 600
  }

  /** Check if this is a client error that won't be fixed by retrying */
  isClientError(): boolean {
    return this.statusCode >= 400 && this.statusCode < 500 && this.statusCode !== 429
  }
}

/**
 * Error thrown when an LLM response is truncated due to token limits
 */
export class LLMTruncationError extends Error {
  constructor(
    public readonly provider: string,
    public readonly partialContent: unknown
  ) {
    super(`${provider} response was truncated due to token limits. Partial content received.`)
    this.name = 'LLMTruncationError'
  }
}

/**
 * Error thrown when the LLM refuses to respond due to a content filter
 */
export class LLMContentFilterError extends Error {
  constructor(
    public readonly provider: string,
    public readonly filterReason?: string
  ) {
    super(`${provider} refused to respond due to content filter${filterReason ? `: ${filterReason}` : ''}`)
    this.name = 'LLMContentFilterError'
  }
}

/**
 * Error thrown when JSON parsing fails
 */
export class LLMJsonParseError extends Error {
  constructor(
    public readonly rawContent: string,
    public readonly parseError: string
  ) {
    super(`Failed to parse LLM JSON response: ${parseError}`)
    this.name = 'LLMJsonParseError'
  }
}
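A sketch of discriminating this error hierarchy at a call site (illustrative; the describeFailure helper is not part of the package). Only types and fields defined above are used.

import {
  LLMApiError,
  LLMContentFilterError,
  LLMTruncationError,
  LLMValidationError,
} from './types'

function describeFailure(err: unknown): string {
  if (err instanceof LLMValidationError) {
    return `Schema mismatch at ${err.feedback.field}: ${err.feedback.error}`
  }
  if (err instanceof LLMContentFilterError) {
    return `Refused by ${err.provider}${err.filterReason ? `: ${err.filterReason}` : ''}`
  }
  if (err instanceof LLMTruncationError) {
    return `Truncated response from ${err.provider}; consider a shorter prompt`
  }
  if (err instanceof LLMApiError) {
    return err.isRateLimited()
      ? `Rate limited by ${err.provider}; retry after ${err.retryAfterMs ?? 'unknown'} ms`
      : `${err.provider} returned HTTP ${err.statusCode}`
  }
  return 'Unknown error'
}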
21
packages/llm-client/tsconfig.json
Normal file
@@ -0,0 +1,21 @@
{
  "compilerOptions": {
    "target": "ES2020",
    "lib": ["ES2020"],
    "module": "ES2020",
    "moduleResolution": "node",
    "strict": true,
    "esModuleInterop": true,
    "skipLibCheck": true,
    "forceConsistentCasingInFileNames": true,
    "declaration": true,
    "declarationMap": true,
    "outDir": "./dist",
    "rootDir": "./src",
    "resolveJsonModule": true,
    "isolatedModules": true,
    "types": ["node"]
  },
  "include": ["src/**/*"],
  "exclude": ["node_modules", "dist", "**/*.test.*"]
}
9
packages/llm-client/vitest.config.ts
Normal file
@@ -0,0 +1,9 @@
import { defineConfig } from 'vitest/config'

export default defineConfig({
  test: {
    globals: true,
    environment: 'node',
    include: ['src/**/*.test.ts'],
  },
})
22779
pnpm-lock.yaml
generated
File diff suppressed because it is too large
@@ -2,6 +2,7 @@ packages:
  - "apps/*"
  - "packages/templates"
  - "packages/abacus-react"
  - "packages/llm-client"
  - "packages/core/client/node"
  - "packages/core/client/typescript"
  - "packages/core/client/browser"
3
public/models/abacus-column-classifier/README.json
Normal file
@@ -0,0 +1,3 @@
{
  "note": "Placeholder - model needs to be trained. Run: npx tsx scripts/train-column-classifier/generateTrainingData.ts and then the Python training script."
}