Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/app/ml-training-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@ export class MlTrainingClient {
this.client = createPromiseClient(MLTrainingService, transport);
}

/**
* Submit a training job.
*
* @example
*
* ```ts
* await mlTrainingClient.submitTrainingJob(
* '<organization-id>',
* '<dataset-id>',
* '<your-model-name>',
* '1.0.0',
* ModelType.SINGLE_LABEL_CLASSIFICATION,
* ['tag1', 'tag2']
* );
* ```
*
* @param organizationId - The organization ID.
* @param datasetId - The dataset ID.
* @param modelName - The model name.
* @param modelVersion - The model version.
* @param modelType - The model type.
* @param tags - The tags.
*/
async submitTrainingJob(
organizationId: string,
datasetId: string,
Expand All @@ -36,6 +59,29 @@ export class MlTrainingClient {
return resp.id;
}

/**
* Submit a training job from a custom training script.
*
* @example
*
* ```ts
* await mlTrainingClient.submitCustomTrainingJob(
* '<organization-id>',
* '<dataset-id>',
* 'viam:classification-tflite',
* '1.0.0',
* '<your-model-name>',
* '1.0.0'
* );
* ```
*
* @param organizationId - The organization ID.
* @param datasetId - The dataset ID.
* @param registryItemId - The registry item ID.
* @param registryItemVersion - The registry item version.
* @param modelName - The model name.
* @param modelVersion - The model version.
*/
async submitCustomTrainingJob(
organizationId: string,
datasetId: string,
Expand All @@ -55,21 +101,69 @@ export class MlTrainingClient {
return resp.id;
}

/**
* Get a training job metadata.
*
* @example
*
* ```ts
* const job = await mlTrainingClient.getTrainingJob('<training-job-id>');
* ```
*
* @param id - The training job ID.
*/
async getTrainingJob(id: string) {
const resp = await this.client.getTrainingJob({ id });
return resp.metadata;
}

/**
* List training jobs.
*
* @example
*
* ```ts
* const jobs = await mlTrainingClient.listTrainingJobs(
* '<organization-id>',
* TrainingStatus.RUNNING
* );
* ```
*
* @param organizationId - The organization ID.
* @param status - The training job status.
*/
async listTrainingJobs(organizationId: string, status: TrainingStatus) {
const resp = await this.client.listTrainingJobs({ organizationId, status });
return resp.jobs;
}

/**
* Cancel a training job.
*
* @example
*
* ```ts
* await mlTrainingClient.cancelTrainingJob('<training-job-id>');
* ```
*
* @param id - The training job ID.
*/
async cancelTrainingJob(id: string) {
await this.client.cancelTrainingJob({ id });
return null;
}

/**
* Delete a completed training job.
*
* @example
*
* ```ts
* await mlTrainingClient.deleteCompletedTrainingJob('<training-job-id>');
* ```
*
* @param id - The training job ID.
*/
async deleteCompletedTrainingJob(id: string) {
await this.client.deleteCompletedTrainingJob({ id });
return null;
Expand Down