This repository was archived by the owner on Jul 1, 2023. It is now read-only.
Multiple return parameters #67
Open
I'm trying to implement a simple/limited version of the meshgrid op. This is what I've got:
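func meshgrid(x: Tensor<Float>, y: Tensor<Float>) -> (Tensor<Float>, Tensor<Float>) {
    // Column vectors: x becomes [nx, 1], y becomes [ny, 1].
    let outputX = x.reshaped(to: [-1, 1])
    let outputY = y.reshaped(to: [-1, 1])
    // Multiplying by a row of ones broadcasts each column. outputX needs ny
    // copies (it is transposed to [ny, nx] below) and outputY needs nx copies,
    // so each ones row is sized from the *other* input.
    let multFactX = Tensor<Float>(ones: [y.scalarCountTensor.scalarized()])
    let multFactY = Tensor<Float>(ones: [x.scalarCountTensor.scalarized()])
    return ((outputX * multFactX).transposed(), outputY * multFactY)
}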
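For reference, a plain-Swift sketch of the result meshgrid is expected to produce may help when checking the tensor version. It works on ordinary arrays so it runs without the TensorFlow module, and it assumes NumPy's default `'xy'` indexing, where inputs of lengths nx (x) and ny (y) both yield ny-by-nx grids. `referenceMeshgrid` is a hypothetical helper, not part of swift-apis.

```swift
// Hypothetical reference meshgrid on plain arrays, using 'xy' indexing:
// both outputs have y.count rows and x.count columns.
func referenceMeshgrid(x: [Float], y: [Float]) -> (xs: [[Float]], ys: [[Float]]) {
    // Every row of xs is a copy of x, one row per element of y.
    let xs = y.map { _ in x }
    // Row i of ys repeats y[i], once per element of x.
    let ys = y.map { yi in [Float](repeating: yi, count: x.count) }
    return (xs: xs, ys: ys)
}
```

Calling it with inputs of different lengths, such as `x: [1, 2, 3]` and `y: [10, 20]`, gives two 2-by-3 grids, which makes a square-only or transposed tensor result easy to spot.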
Of course, `meshgrid` can't be made differentiable as written, because tuples don't conform to `Differentiable`. So I had to implement this:
struct TensorPair<T: TensorFlowFloatingPoint>: Differentiable {
    var first: Tensor<T>
    var second: Tensor<T>

    @differentiable
    init(_ first: Tensor<T>, _ second: Tensor<T>) {
        self.first = first
        self.second = second
    }
}
@differentiable
func meshgrid(x: Tensor<Float>, y: Tensor<Float>) -> TensorPair<Float> {
    // Column vectors: x becomes [nx, 1], y becomes [ny, 1].
    let outputX = x.reshaped(to: [-1, 1])
    let outputY = y.reshaped(to: [-1, 1])
    // Ones rows sized from the other input so both outputs end up [ny, nx].
    let multFactX = Tensor<Float>(ones: [y.scalarCountTensor.scalarized()])
    let multFactY = Tensor<Float>(ones: [x.scalarCountTensor.scalarized()])
    return TensorPair((outputX * multFactX).transposed(), outputY * multFactY)
}
It works, but it's not a very elegant solution. Thoughts:
- Will tuples ever be differentiable?
- This could be made a bit more elegant by returning an array containing the two tensors.
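A rough sketch of the array idea, assuming a recent Swift toolchain that ships the experimental `_Differentiation` module (Swift for TensorFlow itself is archived): `Array` conforms to `Differentiable` when its elements do, so a function can return several values as an array and still be differentiated. `Float` stands in for `Tensor<Float>` so the example runs without TensorFlow, and `twoOutputs` and `loss` are hypothetical names.

```swift
import _Differentiation

// Hypothetical two-output function. [Float] is Differentiable because Float is,
// so the outputs can be returned as an array instead of a tuple.
@differentiable(reverse)
func twoOutputs(_ x: Float, _ y: Float) -> [Float] {
    return [x * y, x + y]
}

// Reduce both outputs to a single scalar so a gradient can be taken.
@differentiable(reverse)
func loss(_ x: Float, _ y: Float) -> Float {
    let outputs = twoOutputs(x, y)
    return outputs[0] + outputs[1]
}
```

Here `gradient(at: Float(2), Float(3), of: loss)` yields the partial derivatives (y + 1, x + 1). Compared with a struct like `TensorPair`, the array version loses the field names, and the number of outputs is no longer checked by the type system.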