-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathMLRecognizer.swift
More file actions
97 lines (81 loc) · 2.74 KB
/
MLRecognizer.swift
File metadata and controls
97 lines (81 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
//
// MLRecognizer.swift
// ARKitCoreML
//
// Created by Jason Clark on 10/11/18.
// Copyright © 2018 Raizlabs. All rights reserved.
//
import ARKit
import UIKit
import SceneKit
import Vision
/// Classifies the visual contents of an `ARImageAnchor`'s projection in an
/// `ARSCNView` using a Core ML model, via the Vision framework.
class MLRecognizer: NSObject {

    // FIX: was `private var` but is never mutated after init — now `let`.
    /// The Core ML model handed to Vision for each classification request.
    private let model: MLModel

    /// Weak to avoid a retain cycle with the scene view that owns this recognizer.
    private weak var sceneView: ARSCNView?

    /// Serial queue so Vision requests run one at a time, off the main thread.
    private let visionQueue = DispatchQueue(label: "com.raizlabs.serialVisionQueue")

    /// - Parameters:
    ///   - model: Compiled Core ML classification model.
    ///   - sceneView: The AR scene view whose captured frames are cropped for classification.
    init(model: MLModel, sceneView: ARSCNView) {
        self.model = model
        self.sceneView = sceneView
        super.init()
    }

    /// Classifies the current camera projection of `imageAnchor`.
    ///
    /// - Parameters:
    ///   - imageAnchor: The detected image anchor whose on-screen projection is cropped and classified.
    ///   - completion: Invoked on the main queue with the top classification identifier, or a failure.
    func classify(imageAnchor: ARImageAnchor, completion: @escaping (Result<String>) -> Void) {
        // Hop to main first: `_classify` reads from the scene view, which is main-thread-only.
        DispatchQueue.main.async { [weak self] in
            self?._classify(imageAnchor: imageAnchor, completion: { result in
                // Deliver the result back on main so callers can touch UI directly.
                DispatchQueue.main.async {
                    completion(result)
                }
            })
        }
    }

}
private extension MLRecognizer {

    /// Performs the crop + Vision/CoreML classification.
    ///
    /// Must be entered on the main queue (it reads from the scene view); the
    /// Vision request itself is performed on `visionQueue`.
    ///
    /// - Parameters:
    ///   - imageAnchor: Anchor whose projection in the scene view is classified.
    ///   - completion: Invoked exactly once with the top classification identifier or a failure.
    func _classify(imageAnchor: ARImageAnchor, completion: @escaping (Result<String>) -> Void) {
        // 1. Crop image of the projection of the anchor
        guard
            let cropped = sceneView?.capturedImage(from: imageAnchor),
            let image = cropped.getOrCreateCGImage()
        else {
            completion(.failure(MLError.cropFailed))
            return
        }

        // 2. Prepare classification result handler
        let classificationResultHandler: VNRequestCompletionHandler = { request, error in
            // FIX: surface Vision's error instead of ignoring the parameter.
            if let error = error {
                completion(.failure(error))
                return
            }
            // FIX: previously a failed cast returned without calling `completion`
            // at all, leaving the caller hanging forever. Now both the failed
            // cast and an empty result set report `.classificationFailed`.
            guard
                let results = request.results as? [VNClassificationObservation],
                let mostLikelyClassification = results.first // classifications are ordered by confidence
            else {
                completion(.failure(MLError.classificationFailed))
                return
            }
            completion(.success(mostLikelyClassification.identifier))
        }

        // 3. Dispatch Vision request
        let requestHandler = VNImageRequestHandler(cgImage: image, options: [:])
        guard let vnModel = try? VNCoreMLModel(for: model)
        else {
            completion(.failure(MLError.loadModelFailed))
            return
        }
        let request = VNCoreMLRequest(
            model: vnModel,
            completionHandler: classificationResultHandler
        )
        request.imageCropAndScaleOption = .centerCrop
        request.usesCPUOnly = true
        // Perform off-main on the serial queue; `perform` is synchronous and can throw.
        visionQueue.async {
            do { try requestHandler.perform([request]) }
            catch { completion(.failure(error)) }
        }
    }

}
extension MLRecognizer {
    /// Lightweight success/failure wrapper delivered by `classify(imageAnchor:completion:)`.
    /// NOTE(review): on Swift 5+ this could likely be replaced by `Swift.Result<Value, Error>` — confirm minimum toolchain before migrating.
    public enum Result<Value> {
        case success(Value)
        case failure(Error)
    }
    /// Failures the recognition pipeline can report.
    public enum MLError: Error {
        /// The anchor's projection could not be captured/cropped from the scene view.
        case cropFailed
        /// `VNCoreMLModel(for:)` threw while wrapping the Core ML model for Vision.
        case loadModelFailed
        /// Vision produced no classification observations for the cropped image.
        case classificationFailed
    }
}