diff --git a/src/app/ml-training-client.ts b/src/app/ml-training-client.ts index 422946734..f2f7fcf93 100644 --- a/src/app/ml-training-client.ts +++ b/src/app/ml-training-client.ts @@ -17,6 +17,29 @@ export class MlTrainingClient { this.client = createPromiseClient(MLTrainingService, transport); } + /** + * Submit a training job. + * + * @example + * + * ```ts + * await mlTrainingClient.submitTrainingJob( + * '', + * '', + * '', + * '1.0.0', + * ModelType.SINGLE_LABEL_CLASSIFICATION, + * ['tag1', 'tag2'] + * ); + * ``` + * + * @param organizationId - The organization ID. + * @param datasetId - The dataset ID. + * @param modelName - The model name. + * @param modelVersion - The model version. + * @param modelType - The model type. + * @param tags - The tags. + */ async submitTrainingJob( organizationId: string, datasetId: string, @@ -36,6 +59,29 @@ export class MlTrainingClient { return resp.id; } + /** + * Submit a training job from a custom training script. + * + * @example + * + * ```ts + * await mlTrainingClient.submitCustomTrainingJob( + * '', + * '', + * 'viam:classification-tflite', + * '1.0.0', + * '', + * '1.0.0' + * ); + * ``` + * + * @param organizationId - The organization ID. + * @param datasetId - The dataset ID. + * @param registryItemId - The registry item ID. + * @param registryItemVersion - The registry item version. + * @param modelName - The model name. + * @param modelVersion - The model version. + */ async submitCustomTrainingJob( organizationId: string, datasetId: string, @@ -55,21 +101,69 @@ export class MlTrainingClient { return resp.id; } + /** + * Get a training job metadata. + * + * @example + * + * ```ts + * const job = await mlTrainingClient.getTrainingJob(''); + * ``` + * + * @param id - The training job ID. + */ async getTrainingJob(id: string) { const resp = await this.client.getTrainingJob({ id }); return resp.metadata; } + /** + * List training jobs. + * + * @example + * + * ```ts + * const jobs = await mlTrainingClient.listTrainingJobs( + * '', + * TrainingStatus.RUNNING + * ); + * ``` + * + * @param organizationId - The organization ID. + * @param status - The training job status. + */ async listTrainingJobs(organizationId: string, status: TrainingStatus) { const resp = await this.client.listTrainingJobs({ organizationId, status }); return resp.jobs; } + /** + * Cancel a training job. + * + * @example + * + * ```ts + * await mlTrainingClient.cancelTrainingJob(''); + * ``` + * + * @param id - The training job ID. + */ async cancelTrainingJob(id: string) { await this.client.cancelTrainingJob({ id }); return null; } + /** + * Delete a completed training job. + * + * @example + * + * ```ts + * await mlTrainingClient.deleteCompletedTrainingJob(''); + * ``` + * + * @param id - The training job ID. + */ async deleteCompletedTrainingJob(id: string) { await this.client.deleteCompletedTrainingJob({ id }); return null;