MPSNNGraph: use a custom Metal compute/render pass during training?

Hello,

I have been following the excellent and informative "Metal for Machine Learning" session from WWDC19 to learn how to do on-device training (I have a specific use case for this), and it is all working really well with MPSNNGraph.

However, I would like to call my own Metal compute/render function (pipeline) to transform the inference result before calculating the loss. Does anyone know if this is possible, and what it would look like in code?

Please see my current code below. At the comment, I need to call an intermediate compute/render function to transform the inference result image before passing it to the MPSNNForwardLossNode.

```swift
// device, lossDescriptor and makeInferenceGraph() are defined elsewhere.

// Labels node: the target RGB image the loss compares against.
let rgbImageNode = MPSNNImageNode(handle: nil)

let inferGraph = makeInferenceGraph()

// Reshape the inference output to 64 x 64 x 4 so it matches the labels.
let reshape = MPSNNReshapeNode(source: inferGraph.resultImage,
                               resultWidth: 64,
                               resultHeight: 64,
                               resultFeatureChannels: 4)

// Need to call a render or compute pipeline here to post-process the inference result image.

let rgbLoss = MPSNNForwardLossNode(source: reshape.resultImage,
                                   labels: rgbImageNode,
                                   lossDescriptor: lossDescriptor)

let initGrad = MPSNNInitialGradientNode(source: rgbLoss.resultImage)

// Build the backward half of the graph from the loss, without force-unwrapping.
guard let gradNodes = initGrad.trainingGraph(withSourceGradient: nil, nodeHandler: nil),
      let trainGraph = MPSNNGraph(device: device,
                                  resultImage: gradNodes[0].resultImage,
                                  resultImageIsNeeded: true) else {
    fatalError("Unable to get training graph.")
}
```
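One possible workaround, sketched here under assumptions and not tested against this exact graph: encode the MPSNNGraph first, then encode a custom compute kernel on the same command buffer to post-process its result. The kernel name `postProcess`, its copy-only body, and the helper functions below are hypothetical placeholders. The sketch also assumes each MPSImage is a single 2D texture (one image, at most 4 feature channels, as with the 64 x 64 x 4 reshape); batches or more channels would need `texture2d_array`. It covers the forward pass only: MPSNNGraph generates gradients only for nodes inside the graph, so a kernel encoded outside it gets no automatic backward pass, and training through it would additionally need a hand-written gradient kernel and the training graph split at that point.

```swift
import Metal
import MetalPerformanceShaders

// Hypothetical post-processing kernel: a plain copy, to be replaced by the real transform.
// Assumes one 2D texture per MPSImage (single image, at most 4 feature channels).
let postProcessSource = """
#include <metal_stdlib>
using namespace metal;

kernel void postProcess(texture2d<float, access::read>  src [[texture(0)]],
                        texture2d<float, access::write> dst [[texture(1)]],
                        uint2 gid [[thread_position_in_grid]])
{
    if (gid.x >= dst.get_width() || gid.y >= dst.get_height()) { return; }
    dst.write(src.read(gid), gid);
}
"""

enum PostProcessError: Error {
    case missingFunction
}

// Compiles the kernel source at run time (an app would more likely use its default library).
func makePostProcessPipeline(device: MTLDevice) throws -> MTLComputePipelineState {
    let library = try device.makeLibrary(source: postProcessSource, options: nil)
    guard let function = library.makeFunction(name: "postProcess") else {
        throw PostProcessError.missingFunction
    }
    return try device.makeComputePipelineState(function: function)
}

// Ceiling division so edge pixels are covered when the size is not a multiple of the group size.
func threadgroupCount(_ extent: Int, groupSize: Int) -> Int {
    return (extent + groupSize - 1) / groupSize
}

// Encodes the custom kernel: reads `source` (texture 0), writes `destination` (texture 1).
@discardableResult
func encodePostProcess(pipeline: MTLComputePipelineState,
                       commandBuffer: MTLCommandBuffer,
                       source: MPSImage,
                       destination: MPSImage) -> Bool {
    guard let encoder = commandBuffer.makeComputeCommandEncoder() else { return false }
    encoder.setComputePipelineState(pipeline)
    encoder.setTexture(source.texture, index: 0)
    encoder.setTexture(destination.texture, index: 1)

    let w = pipeline.threadExecutionWidth
    let h = max(1, pipeline.maxTotalThreadsPerThreadgroup / w)
    let groups = MTLSize(width: threadgroupCount(destination.width, groupSize: w),
                         height: threadgroupCount(destination.height, groupSize: h),
                         depth: 1)
    encoder.dispatchThreadgroups(groups, threadsPerThreadgroup: MTLSize(width: w, height: h, depth: 1))
    encoder.endEncoding()
    return true
}

// Forward-only flow: run the MPS graph, then the custom kernel, on one command buffer.
func encodeInferenceThenPostProcess(graph: MPSNNGraph,
                                    pipeline: MTLComputePipelineState,
                                    commandBuffer: MTLCommandBuffer,
                                    input: MPSImage) -> MPSImage? {
    // With resultImageIsNeeded == true the graph returns a regular (non-temporary) MPSImage.
    guard let inferred = graph.encode(to: commandBuffer, sourceImages: [input]) else { return nil }

    // Destination with the same size and channel count as the graph result.
    let desc = MPSImageDescriptor(channelFormat: .float16,
                                  width: inferred.width,
                                  height: inferred.height,
                                  featureChannels: inferred.featureChannels)
    desc.usage = [.shaderRead, .shaderWrite]
    let output = MPSImage(device: commandBuffer.device, imageDescriptor: desc)

    guard encodePostProcess(pipeline: pipeline, commandBuffer: commandBuffer,
                            source: inferred, destination: output) else { return nil }
    return output
}
```

If the transform can be written with built-in nodes (for example an element-wise `a * x + b` via `MPSCNNNeuronLinearNode(source:a:b:)`), keeping it inside the graph between the reshape and the loss node should let `trainingGraph(withSourceGradient:nodeHandler:)` generate its gradient automatically, which avoids the split entirely.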

Thanks