siq.get_data

   1from pathlib import Path
   2from pathlib import PurePath
   3import os
   4import pandas as pd
   5import math
   6import os.path
   7from os import path
   8from os.path import exists
   9import pickle
  10import sys
  11import numpy as np
  12import random
  13import functools
  14from operator import mul
  15from scipy.sparse.linalg import svds
  16from scipy.stats import pearsonr
  17import re
  18
  19import ants
  20import antspynet
  21import antspyt1w
  22import tensorflow as tf
  23from tensorflow.python.eager.context import eager_mode, graph_mode
  24
  25from multiprocessing import Pool
  26
# Local cache directory for downloaded SIQ data; "~" is expanded once, when the module is imported.
DATA_PATH = os.path.expanduser('~/.siq/')
  28
  29def get_data( name=None, force_download=False, version=0, target_extension='.csv' ):
  30    """
  31    Get SIQ data filename
  32
  33    The first time this is called, it will download data to ~/.siq.
  34    After, it will just read data from disk.  The ~/.siq may need to
  35    be periodically deleted in order to ensure data is current.
  36
  37    Arguments
  38    ---------
  39    name : string
  40        name of data tag to retrieve
  41        Options:
  42            - 'all'
  43
  44    force_download: boolean
  45
  46    version: version of data to download (integer)
  47
  48    Returns
  49    -------
  50    string
  51        filepath of selected data
  52
  53    Example
  54    -------
  55    >>> import siq
  56    >>> siq.get_data()
  57    """
  58    os.makedirs(DATA_PATH, exist_ok=True)
  59
  60    def download_data( version ):
  61        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
  62        target_file_name = "16912366.zip"
  63        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
  64            cache_subdir=DATA_PATH, extract = True )
  65        os.remove( DATA_PATH + target_file_name )
  66
  67    if force_download:
  68        download_data( version = version )
  69
  70
  71    files = []
  72    for fname in os.listdir(DATA_PATH):
  73        if ( fname.endswith(target_extension) ) :
  74            fname = os.path.join(DATA_PATH, fname)
  75            files.append(fname)
  76
  77    if len( files ) == 0 :
  78        download_data( version = version )
  79        for fname in os.listdir(DATA_PATH):
  80            if ( fname.endswith(target_extension) ) :
  81                fname = os.path.join(DATA_PATH, fname)
  82                files.append(fname)
  83
  84    if name == 'all':
  85        return files
  86
  87    datapath = None
  88
  89    for fname in os.listdir(DATA_PATH):
  90        mystem = (Path(fname).resolve().stem)
  91        mystem = (Path(mystem).resolve().stem)
  92        mystem = (Path(mystem).resolve().stem)
  93        if ( name == mystem and fname.endswith(target_extension) ) :
  94            datapath = os.path.join(DATA_PATH, fname)
  95
  96    return datapath
  97
  98
  99
 100from keras.models import Model
 101from keras.layers import (Input, Add, Subtract,
 102                          PReLU, Concatenate,
 103                          UpSampling2D, UpSampling3D,
 104                          Conv2D, Conv2DTranspose,
 105                          Conv3D, Conv3DTranspose)
 106
 107
# Define the DBPN network. This model definition is general to both 2D and 3D.
# Recommended parameters for different upsampling rates can be found in the
# papers by Haris et al. We make one significant change to the original
# architecture: upsampling uses standard interpolation instead of
# convolutional (transposed-convolution) upsampling. For 2D inputs the
# interpolation method is selected with the `interpolation` option.
 114
 115
 116
 117def dbpn(input_image_size,
 118                                                 number_of_outputs=1,
 119                                                 number_of_base_filters=64,
 120                                                 number_of_feature_filters=256,
 121                                                 number_of_back_projection_stages=7,
 122                                                 convolution_kernel_size=(12, 12),
 123                                                 strides=(8, 8),
 124                                                 last_convolution=(3, 3),
 125                                                 number_of_loss_functions=1,
 126                                                 interpolation = 'nearest'
 127                                                ):
 128    """
 129    Creates a Deep Back-Projection Network (DBPN) for single image super-resolution.
 130
 131    This function constructs a Keras model based on the DBPN architecture, which
 132    can be configured for either 2D or 3D inputs. The network uses iterative
 133    up- and down-projection blocks to refine the high-resolution image estimate. A
 134    key modification from the original paper is the option to use standard
 135    interpolation for upsampling instead of deconvolution layers.
 136
 137    Reference:
 138     - Haris, M., Shakhnarovich, G., & Ukita, N. (2018). Deep Back-Projection
 139       Networks For Super-Resolution. In CVPR.
 140
 141    Parameters
 142    ----------
 143    input_image_size : tuple or list
 144        The shape of the input image, including the channel.
 145        e.g., `(None, None, 1)` for 2D or `(None, None, None, 1)` for 3D.
 146
 147    number_of_outputs : int, optional
 148        The number of channels in the output image. Default is 1.
 149
 150    number_of_base_filters : int, optional
 151        The number of filters in the up/down projection blocks. Default is 64.
 152
 153    number_of_feature_filters : int, optional
 154        The number of filters in the initial feature extraction layer. Default is 256.
 155
 156    number_of_back_projection_stages : int, optional
 157        The number of iterative back-projection stages (T in the paper). Default is 7.
 158
 159    convolution_kernel_size : tuple or list, optional
 160        The kernel size for the main projection convolutions. Should match the
 161        dimensionality of the input. Default is (12, 12).
 162
 163    strides : tuple or list, optional
 164        The strides for the up/down sampling operations, defining the
 165        super-resolution factor. Default is (8, 8).
 166
 167    last_convolution : tuple or list, optional
 168        The kernel size of the final reconstruction convolution. Default is (3, 3).
 169
 170    number_of_loss_functions : int, optional
 171        If greater than 1, the model will have multiple identical output branches.
 172        Typically set to 1. Default is 1.
 173
 174    interpolation : str, optional
 175        The interpolation method to use for upsampling layers if not using
 176        transposed convolutions. 'nearest' or 'bilinear'. Default is 'nearest'.
 177
 178    Returns
 179    -------
 180    keras.Model
 181        A Keras model implementing the DBPN architecture for the specified
 182        parameters.
 183    """
 184    idim = len( input_image_size ) - 1
 185    if idim == 2:
 186        myconv = Conv2D
 187        myconv_transpose = Conv2DTranspose
 188        myupsampling = UpSampling2D
 189        shax = ( 1, 2 )
 190        firstConv = (3,3)
 191        firstStrides=(1,1)
 192        smashConv=(1,1)
 193    if idim == 3:
 194        myconv = Conv3D
 195        myconv_transpose = Conv3DTranspose
 196        myupsampling = UpSampling3D
 197        shax = ( 1, 2, 3 )
 198        firstConv = (3,3,3)
 199        firstStrides=(1,1,1)
 200        smashConv=(1,1,1)
 201    def up_block_2d(L, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
 202                    include_dense_convolution_layer=True):
 203        if include_dense_convolution_layer == True:
 204            L = myconv(filters = number_of_filters,
 205                       use_bias=True,
 206                       kernel_size=smashConv,
 207                       strides=firstStrides,
 208                       padding='same')(L)
 209            L = PReLU(alpha_initializer='zero',
 210                      shared_axes=shax)(L)
 211        # Scale up
 212        if idim == 2:
 213            H0 = myupsampling( size = strides, interpolation=interpolation )(L)
 214        if idim == 3:
 215            H0 = myupsampling( size = strides )(L)
 216        H0 = myconv(filters=number_of_filters,
 217                    kernel_size=firstConv,
 218                    strides=firstStrides,
 219                    use_bias=True,
 220                    padding='same')(H0)
 221        H0 = PReLU(alpha_initializer='zero',
 222                   shared_axes=shax)(H0)
 223        # Scale down
 224        L0 = myconv(filters=number_of_filters,
 225                    kernel_size=kernel_size,
 226                    strides=strides,
 227                    kernel_initializer='glorot_uniform',
 228                    padding='same')(H0)
 229        L0 = PReLU(alpha_initializer='zero',
 230                   shared_axes=shax)(L0)
 231        # Residual
 232        E = Subtract()([L0, L])
 233        # Scale residual up
 234        if idim == 2:
 235            H1 = myupsampling( size = strides, interpolation=interpolation  )(E)
 236        if idim == 3:
 237            H1 = myupsampling( size = strides )(E)
 238        H1 = myconv(filters=number_of_filters,
 239                    kernel_size=firstConv,
 240                    strides=firstStrides,
 241                    use_bias=True,
 242                    padding='same')(H1)
 243        H1 = PReLU(alpha_initializer='zero',
 244                   shared_axes=shax)(H1)
 245        # Output feature map
 246        up_block = Add()([H0, H1])
 247        return up_block
 248    def down_block_2d(H, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
 249                    include_dense_convolution_layer=True):
 250        if include_dense_convolution_layer == True:
 251            H = myconv(filters = number_of_filters,
 252                       use_bias=True,
 253                       kernel_size=smashConv,
 254                       strides=firstStrides,
 255                       padding='same')(H)
 256            H = PReLU(alpha_initializer='zero',
 257                      shared_axes=shax)(H)
 258        # Scale down
 259        L0 = myconv(filters=number_of_filters,
 260                    kernel_size=kernel_size,
 261                    strides=strides,
 262                    kernel_initializer='glorot_uniform',
 263                    padding='same')(H)
 264        L0 = PReLU(alpha_initializer='zero',
 265                   shared_axes=shax)(L0)
 266        # Scale up
 267        if idim == 2:
 268            H0 = myupsampling( size = strides, interpolation=interpolation )(L0)
 269        if idim == 3:
 270            H0 = myupsampling( size = strides )(L0)
 271        H0 = myconv(filters=number_of_filters,
 272                    kernel_size=firstConv,
 273                    strides=firstStrides,
 274                    use_bias=True,
 275                    padding='same')(H0)
 276        H0 = PReLU(alpha_initializer='zero',
 277                   shared_axes=shax)(H0)
 278        # Residual
 279        E = Subtract()([H0, H])
 280        # Scale residual down
 281        L1 = myconv(filters=number_of_filters,
 282                    kernel_size=kernel_size,
 283                    strides=strides,
 284                    kernel_initializer='glorot_uniform',
 285                    padding='same')(E)
 286        L1 = PReLU(alpha_initializer='zero',
 287                   shared_axes=shax)(L1)
 288        # Output feature map
 289        down_block = Add()([L0, L1])
 290        return down_block
 291    inputs = Input(shape=input_image_size)
 292    # Initial feature extraction
 293    model = myconv(filters=number_of_feature_filters,
 294                   kernel_size=firstConv,
 295                   strides=firstStrides,
 296                   padding='same',
 297                   kernel_initializer='glorot_uniform')(inputs)
 298    model = PReLU(alpha_initializer='zero',
 299                  shared_axes=shax)(model)
 300    # Feature smashing
 301    model = myconv(filters=number_of_base_filters,
 302                   kernel_size=smashConv,
 303                   strides=firstStrides,
 304                   padding='same',
 305                   kernel_initializer='glorot_uniform')(model)
 306    model = PReLU(alpha_initializer='zero',
 307                  shared_axes=shax)(model)
 308    # Back projection
 309    up_projection_blocks = []
 310    down_projection_blocks = []
 311    model = up_block_2d(model, number_of_filters=number_of_base_filters,
 312      kernel_size=convolution_kernel_size, strides=strides)
 313    up_projection_blocks.append(model)
 314    for i in range(number_of_back_projection_stages):
 315        if i == 0:
 316            model = down_block_2d(model, number_of_filters=number_of_base_filters,
 317              kernel_size=convolution_kernel_size, strides=strides)
 318            down_projection_blocks.append(model)
 319            model = up_block_2d(model, number_of_filters=number_of_base_filters,
 320              kernel_size=convolution_kernel_size, strides=strides)
 321            up_projection_blocks.append(model)
 322            model = Concatenate()(up_projection_blocks)
 323        else:
 324            model = down_block_2d(model, number_of_filters=number_of_base_filters,
 325              kernel_size=convolution_kernel_size, strides=strides,
 326              include_dense_convolution_layer=True)
 327            down_projection_blocks.append(model)
 328            model = Concatenate()(down_projection_blocks)
 329            model = up_block_2d(model, number_of_filters=number_of_base_filters,
 330              kernel_size=convolution_kernel_size, strides=strides,
 331              include_dense_convolution_layer=True)
 332            up_projection_blocks.append(model)
 333            model = Concatenate()(up_projection_blocks)
 334    outputs = myconv(filters=number_of_outputs,
 335                     kernel_size=last_convolution,
 336                     strides=firstStrides,
 337                     padding = 'same',
 338                     kernel_initializer = "glorot_uniform")(model)
 339    if number_of_loss_functions == 1:
 340        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputs)
 341    else:
 342        outputList=[]
 343        for k in range(number_of_loss_functions):
 344            outputList.append(outputs)
 345        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputList)
 346    return deep_back_projection_network_model
 347
 348
 349# generate a random corner index for a patch
 350
 351def get_random_base_ind( full_dims, patchWidth, off=8 ):
 352    """
 353    Generates a random top-left corner index for a patch.
 354
 355    This utility function computes a valid starting index (e.g., [x, y, z])
 356    for extracting a patch from a larger volume, ensuring the patch fits entirely
 357    within the volume's boundaries, accounting for an offset.
 358
 359    Parameters
 360    ----------
 361    full_dims : tuple or list
 362        The dimensions of the full volume (e.g., img.shape).
 363
 364    patchWidth : tuple or list
 365        The dimensions of the patch to be extracted.
 366
 367    off : int, optional
 368        An offset from the edge of the volume to avoid sampling near borders.
 369        Default is 8.
 370
 371    Returns
 372    -------
 373    list
 374        A list of integers representing the starting coordinates for the patch.
 375    """
 376    baseInd = [None,None,None]
 377    for k in range(3):
 378        baseInd[k]=random.sample( range( off, full_dims[k]-1-patchWidth[k] ), 1 )[0]
 379    return baseInd
 380
 381
 382# extract a random patch
 383def get_random_patch( img, patchWidth ):
 384    """
 385    Extracts a random patch from an image with non-zero variance.
 386
 387    This function repeatedly samples a random patch of a specified width from
 388    the input image until it finds one where the standard deviation of pixel
 389    intensities is greater than zero. This is useful for avoiding blank or
 390    uniform patches during training data generation.
 391
 392    Parameters
 393    ----------
 394    img : ants.ANTsImage
 395        The source image from which to extract a patch.
 396
 397    patchWidth : tuple or list
 398        The desired dimensions of the output patch.
 399
 400    Returns
 401    -------
 402    ants.ANTsImage
 403        A randomly extracted patch from the input image.
 404    """
 405    mystd = 0
 406    while mystd == 0:
 407        inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 )
 408        hinds = [None,None,None]
 409        for k in range(len(inds)):
 410            hinds[k] = inds[k] + patchWidth[k]
 411        myimg = ants.crop_indices( img, inds, hinds )
 412        mystd = myimg.std()
 413    return myimg
 414
 415def get_random_patch_pair( img, img2, patchWidth ):
 416    """
 417    Extracts a corresponding random patch from a pair of images.
 418
 419    This function finds a single random location and extracts a patch of the
 420    same size and position from two different input images. It ensures that
 421    both extracted patches have non-zero variance. This is useful for creating
 422    paired training data (e.g., low-res and high-res images).
 423
 424    Parameters
 425    ----------
 426    img : ants.ANTsImage
 427        The first source image.
 428
 429    img2 : ants.ANTsImage
 430        The second source image, spatially aligned with the first.
 431
 432    patchWidth : tuple or list
 433        The desired dimensions of the output patches.
 434
 435    Returns
 436    -------
 437    tuple of ants.ANTsImage
 438        A tuple containing two corresponding patches: (patch_from_img, patch_from_img2).
 439    """
 440    mystd = mystd2 = 0
 441    ct = 0
 442    while mystd == 0 or mystd2 == 0:
 443        inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8  )
 444        hinds = [None,None,None]
 445        for k in range(len(inds)):
 446            hinds[k] = inds[k] + patchWidth[k]
 447        myimg = ants.crop_indices( img, inds, hinds )
 448        myimg2 = ants.crop_indices( img2, inds, hinds )
 449        mystd = myimg.std()
 450        mystd2 = myimg2.std()
 451        ct = ct + 1
 452        if ( ct > 20 ):
 453            return myimg, myimg2
 454    return myimg, myimg2
 455
 456def pseudo_3d_vgg_features( inshape = [128,128,128], layer = 4, angle=0, pretrained=True, verbose=False ):
 457    """
 458    Creates a pseudo-3D VGG feature extractor from a pre-trained 2D VGG model.
 459
 460    This function constructs a 3D VGG-style network and initializes its weights
 461    by "stretching" the weights from a pre-trained 2D VGG19 model (trained on
 462    ImageNet) along a specified axis. This is a technique to transfer 2D
 463    perceptual knowledge to a 3D domain for tasks like perceptual loss.
 464
 465    Parameters
 466    ----------
 467    inshape : list of int, optional
 468        The input shape of the 3D volume, e.g., `[128, 128, 128]`. Default is `[128,128,128]`.
 469
 470    layer : int, optional
 471        The block number of the VGG network from which to extract features. For
 472        VGG19, this corresponds to block `layer` (e.g., layer=4 means 'block4_conv...').
 473        Default is 4.
 474
 475    angle : int, optional
 476        The axis along which to project the 2D weights:
 477        - 0: Axial plane (stretches along Z)
 478        - 1: Coronal plane (stretches along Y)
 479        - 2: Sagittal plane (stretches along X)
 480        Default is 0.
 481
 482    pretrained : bool, optional
 483        If True, loads the stretched ImageNet weights. If False, the model is
 484        randomly initialized. Default is True.
 485
 486    verbose : bool, optional
 487        If True, prints information about the layers being used. Default is False.
 488
 489    Returns
 490    -------
 491    tf.keras.Model
 492        A Keras model that takes a 3D volume as input and outputs the pseudo-3D
 493        feature map from the specified layer and angle.
 494    """
 495    def getLayerScaleFactorForTransferLearning( k, w3d, w2d ):
 496        myfact = np.round( np.prod( w3d[k].shape ) / np.prod(  w2d[k].shape) )
 497        return myfact
 498    vgg19 = tf.keras.applications.VGG19(
 499            include_top = False, weights = "imagenet",
 500            input_shape = [inshape[0],inshape[1],3],
 501            classes = 1000 )
 502    def findLayerIndex( layerName, mdl ):
 503          for k in range( len( mdl.layers ) ):
 504            if layerName == mdl.layers[k].name :
 505                return k - 1
 506          return None
 507    layer_index = layer-1 # findLayerIndex( 'block2_conv2', vgg19 )
 508    vggmodelRaw = antspynet.create_vgg_model_3d(
 509            [inshape[0],inshape[1],inshape[2],1],
 510            number_of_classification_labels = 1000,
 511            layers = [1, 2, 3, 4, 4],
 512            lowest_resolution = 64,
 513            convolution_kernel_size= (3, 3, 3), pool_size = (2, 2, 2),
 514            strides = (2, 2, 2), number_of_dense_units= 4096, dropout_rate = 0,
 515            style = 19, mode = "classification")
 516    if verbose:
 517        print( vggmodelRaw.layers[layer_index] )
 518        print( vggmodelRaw.layers[layer_index].name )
 519        print( vgg19.layers[layer_index] )
 520        print( vgg19.layers[layer_index].name )
 521    feature_extractor_2d = tf.keras.Model(
 522            inputs = vgg19.input,
 523            outputs = vgg19.layers[layer_index].output)
 524    feature_extractor = tf.keras.Model(
 525            inputs = vggmodelRaw.input,
 526            outputs = vggmodelRaw.layers[layer_index].output)
 527    wts_2d = feature_extractor_2d.weights
 528    wts = feature_extractor.weights
 529    def checkwtshape( a, b ):
 530        if len(a.shape) != len(b.shape):
 531                return False
 532        for j in range(len(a.shape)):
 533            if a.shape[j] != b.shape[j]:
 534                return False
 535        return True
 536    for ww in range(len(wts)):
 537        wts[ww]=wts[ww].numpy()
 538        wts_2d[ww]=wts_2d[ww].numpy()
 539        if checkwtshape( wts[ww], wts_2d[ww] ) and ww != 0:
 540            wts[ww]=wts_2d[ww]
 541        elif ww != 0:
 542            # FIXME - should allow doing this across different angles
 543            if angle == 0:
 544                wts[ww][:,:,0,:,:]=wts_2d[ww]/3.0
 545                wts[ww][:,:,1,:,:]=wts_2d[ww]/3.0
 546                wts[ww][:,:,2,:,:]=wts_2d[ww]/3.0
 547            if angle == 1:
 548                wts[ww][:,0,:,:,:]=wts_2d[ww]/3.0
 549                wts[ww][:,1,:,:,:]=wts_2d[ww]/3.0
 550                wts[ww][:,2,:,:,:]=wts_2d[ww]/3.0
 551            if angle == 2:
 552                wts[ww][0,:,:,:,:]=wts_2d[ww]/3.0
 553                wts[ww][1,:,:,:,:]=wts_2d[ww]/3.0
 554                wts[ww][2,:,:,:,:]=wts_2d[ww]/3.0
 555        else:
 556            wts[ww][:,:,:,0,:]=wts_2d[ww]
 557    if pretrained:
 558        feature_extractor.set_weights( wts )
 559        newinput = tf.keras.layers.Rescaling(  255.0, -127.5  )( feature_extractor.input )
 560        feature_extractor2 = feature_extractor( newinput )
 561        feature_extractor = tf.keras.Model( feature_extractor.input, feature_extractor2 )
 562    return feature_extractor
 563
 564def pseudo_3d_vgg_features_unbiased( inshape = [128,128,128], layer = 4, verbose=False ):
 565    """
 566    Create a pseudo-3D VGG-style feature extractor by aggregating axial, coronal,
 567    and sagittal VGG feature representations.
 568
 569    This model extracts features along each principal axis using pre-trained 2D
 570    VGG-style networks and concatenates them to form an unbiased pseudo-3D feature space.
 571
 572    Parameters
 573    ----------
 574    inshape : list of int, optional
 575        The input shape of the 3D volume, default is [128, 128, 128].
 576
 577    layer : int, optional
 578        The VGG feature layer to extract. Higher values correspond to deeper
 579        layers in the pseudo-3D VGG backbone.
 580
 581    verbose : bool, optional
 582        If True, prints debug messages during model construction.
 583
 584    Returns
 585    -------
 586    tf.keras.Model
 587        A TensorFlow Keras model that takes a 3D input volume and outputs the
 588        concatenated pseudo-3D feature representation from the specified layer.
 589
 590    Notes
 591    -----
 592    This is useful for perceptual loss or feature comparison in super-resolution
 593    and image synthesis tasks. The same input is processed in three anatomical
 594    planes (axial, coronal, sagittal), and features are concatenated.
 595
 596    See Also
 597    --------
 598    pseudo_3d_vgg_features : Generates VGG features from a single anatomical plane.
 599    """
 600    f = [
 601        pseudo_3d_vgg_features( inshape, layer, angle=0, pretrained=True, verbose=verbose ),
 602        pseudo_3d_vgg_features( inshape, layer, angle=1, pretrained=True ),
 603        pseudo_3d_vgg_features( inshape, layer, angle=2, pretrained=True ) ]
 604    f1=f[0].inputs
 605    f0o=f[0]( f1 )
 606    f1o=f[1]( f1 )
 607    f2o=f[2]( f1 )
 608    catter = tf.keras.layers.concatenate( [f0o, f1o, f2o ])
 609    feature_extractor = tf.keras.Model( f1, catter )
 610    return feature_extractor
 611
 612def get_grader_feature_network( layer=6 ):
 613    """
 614    Load and extract a ResNet-based feature subnetwork for perceptual loss or quality grading.
 615
 616    This function loads a pre-trained 3D ResNet model ("grader") used for
 617    perceptual feature extraction and returns a subnetwork that outputs activations
 618    from a specified internal layer.
 619
 620    Parameters
 621    ----------
 622    layer : int, optional
 623        The index of the internal ResNet layer whose output should be used as
 624        the feature representation. Default is layer 6.
 625
 626    Returns
 627    -------
 628    tf.keras.Model
 629        A Keras model that outputs features from the specified layer of the
 630        pre-trained 3D ResNet grader model.
 631
 632    Raises
 633    ------
 634    Exception
 635        If the pre-trained weights file (`resnet_grader.h5`) is not found.
 636
 637    Notes
 638    -----
 639    The pre-trained weights should be located in: `~/.antspyt1w/resnet_grader.keras`
 640
 641    This model is typically used to compute perceptual loss by comparing
 642    intermediate activations between target and prediction volumes.
 643
 644    See Also
 645    --------
 646    antspynet.create_resnet_model_3d : Constructs the base ResNet model.
 647    """
 648    grader = antspynet.create_resnet_model_3d(
 649        [None,None,None,1],
 650        lowest_resolution = 32,
 651        number_of_outputs = 4,
 652        cardinality = 1,
 653        squeeze_and_excite = False )
 654    # the folder and data below as available from antspyt1w get_data
 655    graderfn = os.path.expanduser( "~/.antspyt1w/resnet_grader.h5" )
 656    if not exists( graderfn ):
 657        raise Exception("graderfn " + graderfn + " does not exist")
 658    grader.load_weights( graderfn)
 659    #    feature_extractor_23 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[23].output )
 660    #   feature_extractor_44 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[44].output )
 661    return tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[layer].output )
 662
 663
 664def default_dbpn(
 665    strider, # length should equal dimensionality
 666    dimensionality = 3,
 667    nfilt=64,
 668    nff = 256,
 669    convn = 6,
 670    lastconv = 3,
 671    nbp=7,
 672    nChannelsIn=1,
 673    nChannelsOut=1,
 674    option = None,
 675    intensity_model=None,
 676    segmentation_model=None,
 677    sigmoid_second_channel=False,
 678    clone_intensity_to_segmentation=False,
 679    pro_seg = 0,
 680    freeze = False,
 681    verbose=False
 682 ):
 683    """
 684    Constructs a DBPN model based on input parameters, and can optionally
 685    use external models for intensity or segmentation processing.
 686
 687    Args:
 688        strider (list): List of strides, length must match `dimensionality`.
 689        dimensionality (int): Number of dimensions (2 or 3). Default is 3.
 690        nfilt (int): Number of base filters. Default is 64.
 691        nff (int): Number of feature filters. Default is 256.
 692        convn (int): Convolution kernel size. Default is 6.
 693        lastconv (int): Size of the last convolution. Default is 3.
 694        nbp (int): Number of back projection stages. Default is 7.
 695        nChannelsIn (int): Number of input channels. Default is 1.
 696        nChannelsOut (int): Number of output channels. Default is 1.
 697        option (str): Model size option ('tiny', 'small', 'medium', 'large'). Default is None.
 698        intensity_model (tf.keras.Model): Optional external intensity model.
 699        segmentation_model (tf.keras.Model): Optional external segmentation model.
 700        sigmoid_second_channel (bool): If True, applies sigmoid to second channel in output.
 701        clone_intensity_to_segmentation (bool): If True, clones intensity model weights to segmentation.
 702        pro_seg (int): If greater than 0, adds a segmentation arm.
 703        freeze (bool): If True, freezes the layers of the intensity/segmentation models.
 704        verbose (bool): If True, prints detailed logs.
 705
 706    Returns:
 707        Model: A Keras model based on the specified configuration.
 708
 709    Raises:
 710        Exception: If `len(strider)` is not equal to `dimensionality`.
 711    """
 712    if option == 'tiny':
 713        nfilt=32
 714        nff = 64
 715        convn = 3
 716        lastconv = 1
 717        nbp=2
 718    elif option == 'small':
 719        nfilt=32
 720        nff = 64
 721        convn = 6
 722        lastconv = 3
 723        nbp=4
 724    elif option == 'medium':
 725        nfilt=64
 726        nff = 128
 727        convn = 6
 728        lastconv = 3
 729        nbp=4
 730    else:
 731        option='large'
 732    if verbose:
 733        print("Building mode of size: " + option)
 734        if intensity_model is not None:
 735            print("user-passed intensity model will be frozen - only segmentation will train")
 736        if segmentation_model is not None:
 737            print("user-passed segmentation model will be frozen - only intensity will train")
 738
 739    if len(strider) != dimensionality:
 740        raise Exception("len(strider) != dimensionality")
 741    # **model instantiation**: these are close to defaults for the 2x network.<br>
 742    # empirical evidence suggests that making covolutions and strides evenly<br>
 743    # divisible by each other reduces artifacts.  2*3=6.
 744    # ofn='./models/dsr3d_'+str(strider)+'up_' + str(nfilt) + '_' + str( nff ) + '_' + str(convn)+ '_' + str(lastconv)+ '_' + str(os.environ['CUDA_VISIBLE_DEVICES'])+'_v0.0.keras'
 745    if dimensionality == 2 :
 746        mdl = dbpn( (None,None,nChannelsIn),
 747            number_of_outputs=nChannelsOut,
 748            number_of_base_filters=nfilt,
 749            number_of_feature_filters=nff,
 750            number_of_back_projection_stages=nbp,
 751            convolution_kernel_size=(convn, convn),
 752            strides=(strider[0], strider[1]),
 753            last_convolution=(lastconv, lastconv),
 754            number_of_loss_functions=1,
 755            interpolation='nearest')
 756    if dimensionality == 3 :
 757        mdl = dbpn( (None,None,None,nChannelsIn),
 758            number_of_outputs=nChannelsOut,
 759            number_of_base_filters=nfilt,
 760            number_of_feature_filters=nff,
 761            number_of_back_projection_stages=nbp,
 762            convolution_kernel_size=(convn, convn, convn),
 763            strides=(strider[0], strider[1], strider[2]),
 764            last_convolution=(lastconv, lastconv, lastconv), number_of_loss_functions=1, interpolation='nearest')
 765    if sigmoid_second_channel and pro_seg != 0 :
 766        if dimensionality == 2 :
 767            input_image_size = (None,None,2)
 768            if intensity_model is None:
 769                intensity_model = dbpn( (None,None,1),
 770                    number_of_outputs=1,
 771                    number_of_base_filters=nfilt,
 772                    number_of_feature_filters=nff,
 773                    number_of_back_projection_stages=nbp,
 774                    convolution_kernel_size=(convn, convn),
 775                    strides=(strider[0], strider[1]),
 776                    last_convolution=(lastconv, lastconv),
 777                    number_of_loss_functions=1,
 778                    interpolation='nearest')
 779            else:
 780                if freeze:
 781                    for layer in intensity_model.layers:
 782                        layer.trainable = False
 783            if segmentation_model is None:
 784                segmentation_model = dbpn( (None,None,1),
 785                        number_of_outputs=1,
 786                        number_of_base_filters=nfilt,
 787                        number_of_feature_filters=nff,
 788                        number_of_back_projection_stages=nbp,
 789                        convolution_kernel_size=(convn, convn),
 790                        strides=(strider[0], strider[1]),
 791                        last_convolution=(lastconv, lastconv),
 792                        number_of_loss_functions=1, interpolation='linear')
 793            else:
 794                if freeze:
 795                    for layer in segmentation_model.layers:
 796                        layer.trainable = False
 797        if dimensionality == 3 :
 798            input_image_size = (None,None,None,2)
 799            if intensity_model is None:
 800                intensity_model = dbpn( (None,None,None,1),
 801                    number_of_outputs=1,
 802                    number_of_base_filters=nfilt,
 803                    number_of_feature_filters=nff,
 804                    number_of_back_projection_stages=nbp,
 805                    convolution_kernel_size=(convn, convn, convn),
 806                    strides=(strider[0], strider[1], strider[2]),
 807                    last_convolution=(lastconv, lastconv, lastconv),
 808                    number_of_loss_functions=1, interpolation='nearest')
 809            else:
 810                if freeze:
 811                    for layer in intensity_model.layers:
 812                        layer.trainable = False
 813            if segmentation_model is None:
 814                segmentation_model = dbpn( (None,None,None,1),
 815                        number_of_outputs=1,
 816                        number_of_base_filters=nfilt,
 817                        number_of_feature_filters=nff,
 818                        number_of_back_projection_stages=nbp,
 819                        convolution_kernel_size=(convn, convn, convn),
 820                        strides=(strider[0], strider[1], strider[2]),
 821                        last_convolution=(lastconv, lastconv, lastconv),
 822                        number_of_loss_functions=1, interpolation='linear')
 823            else:
 824                if freeze:
 825                    for layer in segmentation_model.layers:
 826                        layer.trainable = False
 827        if verbose:
 828            print( "len intensity_model layers : " + str( len( intensity_model.layers )))
 829            print( "len intensity_model weights : " + str( len( intensity_model.weights )))
 830            print( "len segmentation_model layers : " + str( len( segmentation_model.layers )))
 831            print( "len segmentation_model weights : " + str( len( segmentation_model.weights )))
 832        if clone_intensity_to_segmentation:
 833            for k in range(len( segmentation_model.weights )):
 834                if k < len( intensity_model.weights ):
 835                    if intensity_model.weights[k].shape == segmentation_model.weights[k].shape:
 836                        segmentation_model.weights[k] = intensity_model.weights[k]
 837        inputs = tf.keras.Input(shape=input_image_size)
 838        insplit = tf.split( inputs, 2, dimensionality+1)
 839        outputs = [
 840            intensity_model( insplit[0] ),
 841            tf.nn.sigmoid( segmentation_model( insplit[1] ) ) ]
 842        mdlout = tf.concat( outputs, axis=dimensionality+1 )
 843        return Model(inputs=inputs, outputs=mdlout )
 844    if pro_seg > 0 and intensity_model is not None:
 845        if verbose and freeze:
 846            print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid")
 847        if verbose and not freeze:
 848            print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid")
 849        if freeze:
 850            for layer in intensity_model.layers:
 851                layer.trainable = False
 852        if dimensionality == 2 :
 853            input_image_size = (None,None,2)
 854        elif dimensionality == 3 :
 855            input_image_size = (None, None,None,2)
 856        if dimensionality == 2:
 857            myconv = Conv2D
 858            firstConv = (convn,convn)
 859            firstStrides=(1,1)
 860            smashConv=(pro_seg,pro_seg)
 861        if dimensionality == 3:
 862            myconv = Conv3D
 863            firstConv = (convn,convn,convn)
 864            firstStrides=(1,1,1)
 865            smashConv=(pro_seg,pro_seg,pro_seg)
 866        inputs = tf.keras.Input(shape=input_image_size)
 867        insplit = tf.split( inputs, 2, dimensionality+1)
 868        # define segmentation arm
 869        seggit = intensity_model( insplit[1] )
 870        L0 = myconv(filters=nff,
 871                    kernel_size=firstConv,
 872                    strides=firstStrides,
 873                    kernel_initializer='glorot_uniform',
 874                    padding='same')(seggit)
 875        L1 = myconv(filters=nff,
 876                    kernel_size=firstConv,
 877                    strides=firstStrides,
 878                    kernel_initializer='glorot_uniform',
 879                    padding='same')(L0)
 880        L2 = myconv(filters=1,
 881                    kernel_size=smashConv,
 882                    strides=firstStrides,
 883                    kernel_initializer='glorot_uniform',
 884                    padding='same')(L1)
 885        outputs = [
 886            intensity_model( insplit[0] ),
 887            tf.nn.sigmoid( L2 ) ]
 888        mdlout = tf.concat( outputs, axis=dimensionality+1 )
 889        return Model(inputs=inputs, outputs=mdlout )
 890    return mdl
 891
 892def image_patch_training_data_from_filenames(
 893    filenames,
 894    target_patch_size,
 895    target_patch_size_low,
 896    nPatches = 128,
 897    istest   = False,
 898    patch_scaler=True,
 899    to_tensorflow = False,
 900    verbose = False
 901    ):
 902    """
 903    Generates a batch of paired high- and low-resolution image patches for training.
 904
 905    This function creates training data by taking a list of high-resolution source
 906    images, extracting random patches, and then downsampling them to create
 907    low-resolution counterparts. This provides the (input, ground_truth) pairs
 908    needed to train a super-resolution model.
 909
 910    Parameters
 911    ----------
 912    filenames : list of str
 913        A list of file paths to the high-resolution source images.
 914
 915    target_patch_size : tuple or list of int
 916        The dimensions of the high-resolution (ground truth) patch to extract,
 917        e.g., `(128, 128, 128)`.
 918
 919    target_patch_size_low : tuple or list of int
 920        The dimensions of the low-resolution (input) patch to generate. The ratio
 921        between `target_patch_size` and `target_patch_size_low` determines the
 922        super-resolution factor.
 923
 924    nPatches : int, optional
 925        The number of patch pairs to generate in this batch. Default is 128.
 926
 927    istest : bool, optional
 928        If True, the function also generates a third output array containing patches
 929        that have been naively upsampled using linear interpolation. This is useful
 930        for calculating baseline evaluation metrics (e.g., PSNR) against which the
 931        model's performance can be compared. Default is False.
 932
 933    patch_scaler : bool, optional
 934        If True, scales the intensity of each high-resolution patch to the [0, 1]
 935        range before creating the downsampled version. This can help with
 936        training stability. Default is True.
 937
 938    to_tensorflow : bool, optional
 939        If True, casts the output NumPy arrays to TensorFlow tensors. Default is False.
 940
 941    verbose : bool, optional
 942        If True, prints progress messages during patch generation. Default is False.
 943
 944    Returns
 945    -------
 946    tuple
 947        A tuple of NumPy arrays or TensorFlow tensors.
 948        - If `istest` is False: `(patchesResam, patchesOrig)`
 949            - `patchesResam`: The batch of low-resolution input patches (X_train).
 950            - `patchesOrig`: The batch of high-resolution ground truth patches (y_train).
 951        - If `istest` is True: `(patchesResam, patchesOrig, patchesUp)`
 952            - `patchesUp`: The batch of baseline, linearly-upsampled patches.
 953    """
 954    if verbose:
 955        print("begin image_patch_training_data_from_filenames")
 956    tardim = len( target_patch_size )
 957    strider = []
 958    for j in range( tardim ):
 959        strider.append( np.round( target_patch_size[j]/target_patch_size_low[j]) )
 960    if tardim == 3:
 961        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],target_patch_size[2],1)
 962        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],target_patch_size_low[2],1)
 963    if tardim == 2:
 964        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],1)
 965        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],1)
 966    patchesOrig = np.zeros(shape=shaperhi)
 967    patchesResam = np.zeros(shape=shaperlo)
 968    patchesUp = None
 969    if istest:
 970        patchesUp = np.zeros(shape=patchesOrig.shape)
 971    for myn in range(nPatches):
 972            if verbose:
 973                print(myn)
 974            imgfn = random.sample( filenames, 1 )[0]
 975            if verbose:
 976                print(imgfn)
 977            img = ants.image_read( imgfn ).iMath("Normalize")
 978            if img.components > 1:
 979                img = ants.split_channels(img)[0]
 980            img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
 981            ants.set_origin( img, ants.get_center_of_mass(img) )
 982            img = ants.iMath(img,"Normalize")
 983            spc = ants.get_spacing( img )
 984            newspc = []
 985            for jj in range(len(spc)):
 986                newspc.append(spc[jj]*strider[jj])
 987            interp_type = random.choice( [0,1] )
 988            if True:
 989                imgp = get_random_patch( img, target_patch_size )
 990                imgpmin = imgp.min()
 991                if patch_scaler:
 992                    imgp = imgp - imgpmin
 993                    imgpmax = imgp.max()
 994                    if imgpmax > 0 :
 995                        imgp = imgp / imgpmax
 996                rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type  )
 997                if istest:
 998                    rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0  )
 999                if tardim == 3:
1000                    patchesOrig[myn,:,:,:,0] = imgp.numpy()
1001                    patchesResam[myn,:,:,:,0] = rimgp.numpy()
1002                    if istest:
1003                        patchesUp[myn,:,:,:,0] = rimgbi.numpy()
1004                if tardim == 2:
1005                    patchesOrig[myn,:,:,0] = imgp.numpy()
1006                    patchesResam[myn,:,:,0] = rimgp.numpy()
1007                    if istest:
1008                        patchesUp[myn,:,:,0] = rimgbi.numpy()
1009    if to_tensorflow:
1010        patchesOrig = tf.cast( patchesOrig, "float32")
1011        patchesResam = tf.cast( patchesResam, "float32")
1012    if istest:
1013        if to_tensorflow:
1014            patchesUp = tf.cast( patchesUp, "float32")
1015    return patchesResam, patchesOrig, patchesUp
1016
1017
def seg_patch_training_data_from_filenames(
    filenames,
    target_patch_size,
    target_patch_size_low,
    nPatches = 128,
    istest   = False,
    patch_scaler=True,
    to_tensorflow = False,
    verbose = False
    ):
    """
    Generates a batch of paired training data containing both images and segmentations.

    This function extends `image_patch_training_data_from_filenames` by adding a
    second channel to the data. For each extracted image patch, it also generates
    a corresponding binary mask: the patch is labelled with
    `ants.threshold_image(patch, "Otsu", 2)` and one randomly chosen label
    (1 or 2) is extracted. This is useful for training multi-task models that
    perform super-resolution on both an image and its associated segmentation
    simultaneously.

    Parameters
    ----------
    filenames : list of str
        A list of file paths to the high-resolution source images.

    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patch, e.g., `(128, 128, 128)`.
        Must have length 2 or 3.

    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution input patch. Must have the same
        length as `target_patch_size`.

    nPatches : int, optional
        The number of patch pairs to generate. Default is 128.

    istest : bool, optional
        If True, also generates a third output array containing baseline upsampled
        intensity images (channel 0 only). Default is False.

    patch_scaler : bool, optional
        If True, scales the intensity of each image patch to the [0, 1] range.
        Default is True.

    to_tensorflow : bool, optional
        If True, casts the output NumPy arrays to float32 TensorFlow tensors.
        Default is False.

    verbose : bool, optional
        If True, prints progress messages. Default is False.

    Returns
    -------
    tuple
        Always a 3-tuple `(patchesResam, patchesOrig, patchesUp)` whose arrays
        have a trailing channel dimension of 2:
        - Channel 0: the intensity image.
        - Channel 1: the segmentation mask. It is binary at high resolution.
          The low-resolution copy is resampled with the same randomly chosen
          interpolator as the image, so it may hold fractional values when
          linear interpolation was used.
        `patchesUp` is `None` unless `istest` is True. When present, only its
        channel 0 is filled; channel 1 stays zero.

    Raises
    ------
    ValueError
        If `target_patch_size` does not have length 2 or 3, or if
        `target_patch_size_low` has a different length.
    """
    if verbose:
        print("begin seg_patch_training_data_from_filenames")
    tardim = len( target_patch_size )
    if tardim not in (2, 3):
        raise ValueError(
            "target_patch_size must have length 2 or 3; got length " + str( tardim ) )
    if len( target_patch_size_low ) != tardim:
        raise ValueError(
            "target_patch_size and target_patch_size_low must have the same length" )
    nchan = 2
    # Per-axis downsampling factor between the high- and low-resolution grids.
    strider = [ np.round( target_patch_size[j] / target_patch_size_low[j] )
                for j in range( tardim ) ]
    patchesOrig = np.zeros( shape=(nPatches, *target_patch_size, nchan) )
    patchesResam = np.zeros( shape=(nPatches, *target_patch_size_low, nchan) )
    patchesUp = np.zeros( shape=patchesOrig.shape ) if istest else None
    for myn in range( nPatches ):
        if verbose:
            print(myn)
        imgfn = random.sample( filenames, 1 )[0]
        if verbose:
            print(imgfn)
        img = ants.image_read( imgfn ).iMath("Normalize")
        if img.components > 1:
            # Multi-component images: keep only the first channel.
            img = ants.split_channels(img)[0]
        # Crop to the voxels whose normalized intensity lies in [0.05, 1].
        img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
        ants.set_origin( img, ants.get_center_of_mass(img) )
        img = ants.iMath(img,"Normalize")
        spc = ants.get_spacing( img )
        # Coarser spacing used to synthesize the low-resolution patch.
        newspc = [ spc[jj] * strider[jj] for jj in range( len( spc ) ) ]
        # Randomly pick the downsampling interpolator (ANTs: 0=linear, 1=nearest)
        # and which Otsu label becomes the foreground of the mask.
        interp_type = random.choice( [0,1] )
        seg_class = random.choice( [1,2] )
        imgp = get_random_patch( img, target_patch_size )
        if patch_scaler:
            imgp = imgp - imgp.min()
            imgpmax = imgp.max()
            if imgpmax > 0 :
                imgp = imgp / imgpmax
        segp = ants.threshold_image( imgp, "Otsu", 2 ).threshold_image( seg_class, seg_class )
        rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type  )
        rsegp = ants.resample_image( segp, newspc, use_voxels = False, interp_type=interp_type  )
        # `...` covers every spatial axis, so one assignment serves 2D and 3D.
        patchesOrig[myn, ..., 0] = imgp.numpy()
        patchesResam[myn, ..., 0] = rimgp.numpy()
        patchesOrig[myn, ..., 1] = segp.numpy()
        patchesResam[myn, ..., 1] = rsegp.numpy()
        if istest:
            # Baseline: linearly resample the low-res intensity back to the original spacing.
            rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0  )
            patchesUp[myn, ..., 0] = rimgbi.numpy()
    if to_tensorflow:
        patchesOrig = tf.cast( patchesOrig, "float32")
        patchesResam = tf.cast( patchesResam, "float32")
        if istest:
            patchesUp = tf.cast( patchesUp, "float32")
    return patchesResam, patchesOrig, patchesUp
1144
def read( filename ):
    """
    Reads an image or a NumPy array from a file.

    Paths ending in the literal suffix ``.npy`` are loaded with `numpy.load`;
    every other path is passed to `ants.image_read` (standard medical image
    formats such as .nii.gz or .mha).

    Parameters
    ----------
    filename : str or os.PathLike
        The full path to the file to be read.

    Returns
    -------
    ants.ANTsImage or np.ndarray
        The loaded data object, either as an ANTsImage or a NumPy array.
    """
    # os.fspath accepts both str and pathlib.Path.
    path = os.fspath( filename )
    # Literal suffix test: a regex such as ".npy" would treat "." as a wildcard
    # and misroute paths like "scan_xnpy.nii.gz" to numpy.load.
    if path.endswith( ".npy" ):
        return np.load( path )
    return ants.image_read( path )
1170
1171
def auto_weight_loss( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    based on intensity (MSE), perceptual similarity (feature), and total variation (TV).

    The MSE weight is fixed at 10.0. The feature and TV weights are chosen so
    that, on this batch, the weighted feature term equals `feature` times the
    weighted MSE term and the weighted TV term equals `tv` times the weighted
    MSE term.

    Parameters
    ----------
    mdl : tf.keras.Model
        A trained or untrained model to evaluate predictions on input `x`.

    feature_extractor : tf.keras.Model
        A model that extracts intermediate features from the input. Commonly a VGG or ResNet
        trained on a perceptual task.

    x : tf.Tensor
        Input batch to the model.

    y : tf.Tensor
        Ground truth target for `x`: a rank-4 batch of 2D images or a rank-5
        batch of 3D volumes (presumably batch-first, channels-last).

    feature : float, optional
        Weighting factor for the feature (perceptual) term in the loss. Default is 2.0.

    tv : float, optional
        Weighting factor for the total variation term in the loss. Default is 0.1.

    verbose : bool, optional
        If True, prints each component of the loss and its scaled value.

    Returns
    -------
    list of float
        A list of computed weights in the order:
        `[msq_weight, feature_weight, tv_weight]`

    Raises
    ------
    ValueError
        If `y` is not rank 4 or rank 5.

    Notes
    -----
    The total loss (to be used during training) can then be constructed as:

        `L = msq_weight * MSE + feature_weight * perceptual_loss + tv_weight * TV`

    This function is typically used to balance loss terms before training.
    """
    # Validate the target rank up front: it fixes the spatial dimensionality
    # and the non-batch axes that the per-sample means are taken over.
    if len( y.shape ) == 5:
        tdim = 3
        myax = [1,2,3,4]
    elif len( y.shape ) == 4:
        tdim = 2
        myax = [1,2,3]
    else:
        raise ValueError(
            "y must be rank 4 (2D images) or rank 5 (3D volumes); got rank "
            + str( len( y.shape ) ) )
    y_pred = mdl( x )
    squared_difference = tf.square( y - y_pred)
    # Per-sample terms: the batch axis (0) is kept.
    msqTerm = tf.reduce_mean(squared_difference, axis=myax)
    temp1 = feature_extractor(y)
    temp2 = feature_extractor(y_pred)
    feature_difference = tf.square(temp1-temp2)
    featureTerm = tf.reduce_mean(feature_difference, axis=myax)
    msqw = 10.0
    # featw * featureTerm == feature * msqw * msqTerm
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        # Each volume (D, H, W, C) is passed to tf.image.total_variation as a
        # batch of D 2D slices; the per-slice mean is summed over the batch.
        for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version
            sqzd = y_pred[k,:,:,:,:]
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) )
    if tdim == 2:
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) )
    # tvw * mytv == tv * msqw * msqTerm
    tvw = tv * msqw * msqTerm / mytv
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str(  mytv * tvw ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean()]
    return wts
1244
def auto_weight_loss_seg( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, dice=0.5, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    that includes MSE, perceptual similarity, total variation, and segmentation Dice loss.

    The MSE weight is fixed at 10.0. The remaining weights are chosen so that,
    on this batch, each weighted term equals its relative factor (`feature`,
    `tv`, `dice`) times the weighted MSE term. The Dice weight is returned as
    an absolute value.

    Parameters
    ----------
    mdl : tf.keras.Model
        A segmentation + super-resolution model that outputs both image and label predictions.

    feature_extractor : tf.keras.Model
        Feature extractor model used to compute perceptual similarity loss.

    x : tf.Tensor
        Input tensor to the model.

    y : tf.Tensor
        Target tensor with two channels: [intensity_image, segmentation_label].
        Must be rank 4 (2D) or rank 5 (3D).

    feature : float, optional
        Relative weight of the perceptual feature loss term. Default is 2.0.

    tv : float, optional
        Relative weight of the total variation (TV) term. Default is 0.1.

    dice : float, optional
        Relative weight of the Dice loss term (for segmentation agreement). Default is 0.5.

    verbose : bool, optional
        If True, prints the scaled values of each component loss.

    Returns
    -------
    list of float
        A list of loss term weights in the order:
        `[msq_weight, feature_weight, tv_weight, dice_weight]`

    Raises
    ------
    ValueError
        If `y` is not rank 4 or rank 5.

    Notes
    -----
    - The input and output tensors must be shaped such that the last axis is 2:
      channel 0 is intensity, channel 1 is segmentation.
    - This is useful for dual-task networks that predict both high-res images
      and associated segmentation masks.

    See Also
    --------
    binary_dice_loss : Computes Dice loss between predicted and ground-truth masks.
    """
    if len( y.shape ) == 5:
        tdim = 3
        myax = [1,2,3,4]
    elif len( y.shape ) == 4:
        tdim = 2
        myax = [1,2,3]
    else:
        raise ValueError(
            "y must be rank 4 (2D images) or rank 5 (3D volumes); got rank "
            + str( len( y.shape ) ) )
    y_pred = mdl( x )
    # Split the trailing channel axis into intensity (0) and segmentation (1).
    y_intensity, y_seg = tf.split( y, 2, axis=tdim+1 )
    y_intensity_p, y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )
    squared_difference = tf.square( y_intensity - y_intensity_p )
    # Per-sample terms: the batch axis (0) is kept.
    msqTerm = tf.reduce_mean(squared_difference, axis=myax)
    temp1 = feature_extractor(y_intensity)
    temp2 = feature_extractor(y_intensity_p)
    feature_difference = tf.square(temp1-temp2)
    featureTerm = tf.reduce_mean(feature_difference, axis=myax)
    msqw = 10.0
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        # NOTE(review): y_pred[k,:,:,:,0] is rank 3 (D, H, W), which
        # tf.image.total_variation reads as one (height, width, channels) image.
        # TV is therefore taken along the first two spatial axes only. This differs
        # from auto_weight_loss (slice-wise over a 4D tensor); confirm it matches
        # the TV term used during training before changing it.
        for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version
            sqzd = y_pred[k,:,:,:,0]
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) )
    if tdim == 2:
        # Keep the channel axis (0:1) so the input stays a 4D batch of
        # (H, W, 1) images. A rank-3 (B, H, W) slice would be read as a single
        # image and mix variation across different samples in the batch.
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0:1] ) )
    tvw = tv * msqw * msqTerm / mytv
    mydice = binary_dice_loss( y_seg, y_seg_p )
    mydice = tf.reduce_mean( mydice )
    dicew = dice * msqw * msqTerm / mydice
    # Absolute value: the sign of dicew follows the sign of the Dice loss value.
    dicewt = np.abs( dicew.numpy().mean() )
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str(  mytv * tvw ) )
        print( "Dice: " + str( mydice * dicewt ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean(), dicewt ]
    return wts
1331
def numpy_generator( filenames ):
    """
    Stub generator that produces exactly one all-``None`` triple.

    The argument is accepted for signature compatibility but never read. The
    single yielded item mirrors the ``(low_res, high_res, upsampled)`` layout
    of the real patch generators, with every slot set to ``None``.

    Parameters
    ----------
    filenames : any
        Ignored.

    Yields
    ------
    tuple
        ``(None, None, None)``, once.
    """
    empty_triple = (None,) * 3
    yield empty_triple
1352
def image_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image patches for model training.

    This function continuously generates batches of low-resolution (input) and
    high-resolution (ground truth) image patches. It is designed to be fed
    directly into a Keras `model.fit()` call. A fresh batch is produced for
    every item yielded.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution (ground truth) patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution (input) patches.
    patch_scaler : bool, optional
        If True, scales patch intensities to [0, 1]. Default is True.
    istest : bool, optional
        If True, every yielded item has a third element: a baseline
        linearly upsampled version of the low-resolution patch for comparison.
        Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    -------
    tuple
        A tuple of float32 TensorFlow tensors ready for training or evaluation.
        - If `istest` is False: `(low_res_batch, high_res_batch)`
        - If `istest` is True: `(low_res_batch, high_res_batch, baseline_upsampled_batch)`

    See Also
    --------
    image_patch_training_data_from_filenames : The function that performs the
                                               underlying patch extraction.
    """
    while True:
        patchesResam, patchesOrig, patchesUp = image_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest   = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        # Exactly one item per batch. Without the else, test mode would also
        # yield a 2-tuple of the same batch, alternating tuple structures.
        if istest:
            yield (patchesResam, patchesOrig, patchesUp)
        else:
            yield (patchesResam, patchesOrig)
1413
1414
def seg_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image and segmentation patches.

    This function continuously generates batches of multi-channel patches, where
    one channel is the intensity image and the other is a segmentation mask.
    It is designed for training multi-task super-resolution models. A fresh
    batch is produced for every item yielded.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution patches.
    patch_scaler : bool, optional
        If True, scales the intensity channel of patches to [0, 1]. Default is True.
    istest : bool, optional
        If True, every yielded item has a third element: a baseline upsampled
        patch for comparison (only its intensity channel is filled).
        Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    -------
    tuple
        A tuple of two-channel float32 TensorFlow tensors. Channel 0 contains
        the intensity image and Channel 1 contains the segmentation mask.
        - If `istest` is False: `(low_res_batch, high_res_batch)`
        - If `istest` is True: `(low_res_batch, high_res_batch, baseline_upsampled_batch)`

    See Also
    --------
    seg_patch_training_data_from_filenames : The function that performs the
                                             underlying patch extraction.
    image_generator : A similar generator for intensity-only data.
    """
    while True:
        patchesResam, patchesOrig, patchesUp = seg_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest   = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        # Exactly one item per batch. Without the else, test mode would also
        # yield a 2-tuple of the same batch, alternating tuple structures.
        if istest:
            yield (patchesResam, patchesOrig, patchesUp)
        else:
            yield (patchesResam, patchesOrig)
1475
1476
def train(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size,
    target_patch_size_low,
    output_prefix,
    n_test = 8,
    learning_rate=5e-5,
    feature_layer = 6,
    feature = 2,
    tv = 0.1,
    max_iterations = 1000,
    batch_size = 1,
    save_all_best = False,
    feature_type = 'grader',
    check_eval_data_iteration = 20,
    verbose = False  ):
    """
    Orchestrates the training process for a super-resolution model.

    This function handles the entire training loop, including setting up data
    generators, defining a composite loss function, automatically balancing loss
    weights, iteratively training the model, periodically evaluating performance,
    and saving the best-performing model weights.

    Parameters
    ----------
    mdl : tf.keras.Model
        The Keras model to be trained.
    filenames_train : list of str
        List of file paths for the training dataset.
    filenames_test : list of str
        List of file paths for the validation/testing dataset.
    target_patch_size : tuple or list
        The dimensions of the high-resolution target patches.
    target_patch_size_low : tuple or list
        The dimensions of the low-resolution input patches.
    output_prefix : str
        A prefix for all output files (e.g., model weights, training logs).
    n_test : int, optional
        The number of validation patches to use for evaluation. Default is 8.
    learning_rate : float, optional
        The learning rate for the Adam optimizer. Default is 5e-5.
    feature_layer : int, optional
        The layer index from the feature extractor to use for perceptual loss.
        Default is 6.
    feature : float, optional
        The relative weight of the perceptual (feature) loss term. Default is 2.0.
    tv : float, optional
        The relative weight of the Total Variation (TV) regularization term.
        Default is 0.1.
    max_iterations : int, optional
        The total number of training iterations to run. Default is 1000.
    batch_size : int, optional
        Number of samples the 3-D TV term iterates over inside the loss. It is
        not passed to ``fit``; training batches come from the generator.
        Default is 1.
    save_all_best : bool, optional
        If True, saves a new model file every time validation loss improves.
        If False, overwrites the single best model file. Default is False.
    feature_type : str, optional
        The type of feature extractor for perceptual loss. Options: 'grader',
        'vgg', 'vggrandom'. Default is 'grader'.
    check_eval_data_iteration : int, optional
        The frequency (in iterations) at which to run validation and save logs.
        Default is 20.
    verbose : bool, optional
        If True, prints detailed progress information. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the training history, with columns for training
        loss, validation loss, PSNR, and baseline PSNR over iterations.

    Raises
    ------
    ValueError
        If `feature_type` is not one of the supported options, or if the test
        patches are neither 4-D (2-D images) nor 5-D (3-D images).

    Notes
    -----
    Each iteration runs ``mdl.fit`` for two epochs of four steps; the
    'train_loss' and 'test_loss' columns record the first of those epochs.
    Loss weights are cached in ``<output_prefix>_training_weights.csv`` and
    reused on later calls with the same prefix.
    """
    colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin']
    training_path = np.zeros( [ max_iterations, len(colnames) ] )
    training_weights = np.zeros( [1,3] )
    if verbose:
        print("begin get feature extractor " + feature_type)
    if feature_type == 'grader':
        feature_extractor = get_grader_feature_network( feature_layer )
    elif feature_type == 'vggrandom':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False )
    elif feature_type == 'vgg':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer )
    else:
        raise ValueError("feature type does not exist")
    if verbose:
        print("begin train generator")
    mydatgen = image_generator(
        filenames_train,
        nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=False , verbose=False)

    def _fresh_test_batch():
        # Builds a brand-new test generator for every draw and takes one item.
        # NOTE(review): the test generators are never advanced twice. The
        # seg_generator in this file yields a 2-tuple after each 3-tuple when
        # istest=True; presumably image_generator does too, so a second next()
        # on the same generator would not unpack into three values. Confirm.
        gen = image_generator( filenames_test, nPatches=1,
            target_patch_size=target_patch_size,
            target_patch_size_low=target_patch_size_low,
            istest=True, verbose=True)
        return next( gen )

    if verbose:
        print("begin test generator")
    # Single-sample batch used as Keras validation_data and for weight balancing.
    patchesResamTeTf, patchesOrigTeTf, _ = _fresh_test_batch()
    ndim = len( patchesOrigTeTf.shape )
    if ndim not in ( 4, 5 ):
        raise ValueError(
            "expected 4-D (2-D image) or 5-D (3-D image) patches, got rank " + str( ndim ) )
    tdim = ndim - 2  # spatial dimensionality: 2 or 3
    if verbose:
        print("begin train generator #2 at dim: " + str( tdim))
    # Larger evaluation set of n_test independent draws (at least one), stacked on axis 0.
    test_draws = [ _fresh_test_batch() for _ in range( max( n_test, 1 ) ) ]
    patchesResamTeTfB = tf.concat( [ d[0] for d in test_draws ], axis=0 )
    patchesOrigTeTfB = tf.concat( [ d[1] for d in test_draws ], axis=0 )
    patchesUpTeTfB = tf.concat( [ d[2] for d in test_draws ], axis=0 )
    if verbose:
        print("begin auto_weight_loss")
    wts_csv = output_prefix + "_training_weights.csv"
    if exists( wts_csv ):
        # Reuse weights computed by an earlier run with the same prefix.
        wtsdf = pd.read_csv( wts_csv )
        wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0]]
        if verbose:
            print( "preset weights:" )
    else:
        with eager_mode():
            wts = auto_weight_loss( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf,
                feature=feature, tv=tv )
        for k in range(len(wts)):
            training_weights[0,k]=wts[k]
        pd.DataFrame(training_weights, columns = ["msq","feat","tv"] ).to_csv( wts_csv )
        if verbose:
            print( "automatic weights:" )
    if verbose:
        print( wts )
    def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], mybs = batch_size ):
        """Composite loss: MSE + Perceptual Loss + Total Variation."""
        rank = len( y_true.shape )
        if rank not in ( 4, 5 ):
            raise ValueError(
                "expected 4-D (2-D image) or 5-D (3-D image) tensors, got rank " + str( rank ) )
        tdim = rank - 2
        myax = list( range( 1, rank ) )  # every axis except the batch axis
        squared_difference = tf.square(y_true - y_pred)
        msqTerm = tf.reduce_mean(squared_difference, axis=myax)
        temp1 = feature_extractor(y_true)
        temp2 = feature_extractor(y_pred)
        feature_difference = tf.square(temp1-temp2)
        featureTerm = tf.reduce_mean(feature_difference, axis=myax)
        loss = msqTerm * msqwt + featureTerm * fw
        mytv = tf.cast( 0.0, 'float32')
        if tdim == 3:
            # Loops over a fixed count because the batch dimension is not
            # statically known here.
            # NOTE(review): indexing fails if the actual batch is smaller than
            # batch_size, and larger batches only contribute their first
            # batch_size samples to the TV term.
            for k in range( mybs ):
                # y_pred[k] is 4-D, so total_variation treats axis 0 as a batch of 2-D slices.
                sqzd = y_pred[k,:,:,:,:]
                mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt
        if tdim == 2:
            mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) ) * tvwt
        return( loss + mytv )
    if verbose:
        print("begin model compilation")
    opt = tf.keras.optimizers.Adam( learning_rate=learning_rate )
    mdl.compile(optimizer=opt, loss=my_loss_6)
    bestValLoss=1e12
    if verbose:
        print( "begin training", flush=True  )
    for myrs in range( max_iterations ):
        tracker = mdl.fit( mydatgen,  epochs=2, steps_per_epoch=4, verbose=1,
            validation_data=(patchesResamTeTf,patchesOrigTeTf) )
        training_path[myrs,0]=tracker.history['loss'][0]
        training_path[myrs,1]=tracker.history['val_loss'][0]
        training_path[myrs,2]=0
        print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True  )
        if myrs % check_eval_data_iteration == 0:
            with tf.device("/cpu:0"):
                myofn = output_prefix + "_best_mdl.keras"
                if save_all_best:
                    myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras"
                tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB )
                if ( tester < bestValLoss ):
                    print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True )
                    bestValLoss = tester
                    tf.keras.models.save_model( mdl, myofn )
                    training_path[myrs,2]=1
                pp = mdl.predict( patchesResamTeTfB, batch_size = 1 )
                # Intensities are rescaled by 220 before PSNR with max_val=255.
                myssimSR = tf.image.psnr( pp * 220, patchesOrigTeTfB* 220, max_val=255 )
                myssimSR = tf.reduce_mean( myssimSR ).numpy()
                myssimBI = tf.image.psnr( patchesUpTeTfB * 220, patchesOrigTeTfB* 220, max_val=255 )
                myssimBI = tf.reduce_mean( myssimBI ).numpy()
                print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ), flush=True  )
                training_path[myrs,3]=myssimSR # psnr
                training_path[myrs,4]=myssimBI # psnrlin
                pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" )
    training_path = pd.DataFrame(training_path, columns = colnames )
    return training_path
1687
1688
def binary_dice_loss(y_true, y_pred):
    """
    Computes the negative soft Dice coefficient for binary segmentation.

    The Dice coefficient measures the overlap of two masks. This loss returns
    ``-Dice`` (not ``1 - Dice``), which is minimized by maximizing overlap.
    A small smoothing factor is added to numerator and denominator to avoid
    division by zero when both the prediction and the ground truth are empty.
    The whole batch is flattened and scored as a single volume.

    Parameters
    ----------
    y_true : tf.Tensor
        The ground truth binary segmentation mask. Values should be 0 or 1.
        It is cast to ``y_pred``'s dtype, so integer or boolean masks are accepted.
    y_pred : tf.Tensor
        The predicted segmentation mask, typically with values in [0, 1]
        from a sigmoid activation.

    Returns
    -------
    tf.Tensor
        Scalar equal to ``-(2*sum(t*p) + s) / (sum(t) + sum(p) + s)`` with
        ``s = 1e-4``. It is -1 for a perfect binary match (including the case
        where both masks are empty) and approaches 0 when there is no overlap.
    """
    smoothing_factor = 1e-4
    # Plain TF ops rather than tf.keras.backend.flatten/sum, which are not
    # available in the Keras 3 backend module.
    y_pred_f = tf.reshape(y_pred, [-1])
    y_true_f = tf.cast(tf.reshape(y_true, [-1]), y_pred_f.dtype)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    # This is -1 * Dice Similarity Coefficient
    return -1 * (2 * intersection + smoothing_factor) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smoothing_factor)
1721
def train_seg(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size,
    target_patch_size_low,
    output_prefix,
    n_test = 8,
    learning_rate=5e-5,
    feature_layer = 6,
    feature = 2,
    tv = 0.1,
    dice = 0.5,
    max_iterations = 1000,
    batch_size = 1,
    save_all_best = False,
    feature_type = 'grader',
    check_eval_data_iteration = 20,
    verbose = False  ):
    """
    Orchestrates training for a multi-task image and segmentation SR model.

    This function extends the `train` function to handle models that predict
    both a super-resolved image and a super-resolved segmentation mask. It uses
    a four-component composite loss: MSE (for image), a perceptual loss (for
    image), Total Variation (for image), and Dice loss (for segmentation).

    Parameters
    ----------
    mdl : tf.keras.Model
        The 2-channel Keras model to be trained (channel 0 intensity,
        channel 1 segmentation).
    filenames_train : list of str
        List of file paths for the training dataset.
    filenames_test : list of str
        List of file paths for the validation/testing dataset.
    target_patch_size : tuple or list
        The dimensions of the high-resolution target patches.
    target_patch_size_low : tuple or list
        The dimensions of the low-resolution input patches.
    output_prefix : str
        A prefix for all output files.
    n_test : int, optional
        Number of validation patches for evaluation. Default is 8.
    learning_rate : float, optional
        Learning rate for the Adam optimizer. Default is 5e-5.
    feature_layer : int, optional
        Layer from the feature extractor for perceptual loss. Default is 6.
    feature : float, optional
        Relative weight of the perceptual loss term. Default is 2.0.
    tv : float, optional
        Relative weight of the Total Variation regularization term. Default is 0.1.
    dice : float, optional
        Relative weight of the Dice loss term for the segmentation mask.
        Default is 0.5.
    max_iterations : int, optional
        Total number of training iterations. Default is 1000.
    batch_size : int, optional
        Number of samples the 3-D TV term iterates over inside the loss. It is
        not passed to ``fit``. Default is 1.
    save_all_best : bool, optional
        If True, saves all models that improve validation loss. Default is False.
    feature_type : str, optional
        Type of feature extractor for perceptual loss: 'grader', 'vggrandom',
        or anything else for the unbiased VGG extractor. Default is 'grader'.
    check_eval_data_iteration : int, optional
        Frequency (in iterations) for running validation. Default is 20.
    verbose : bool, optional
        If True, prints detailed progress information. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the training history, including columns for losses
        and evaluation metrics like PSNR and Dice score.

    Raises
    ------
    ValueError
        If the test patches are neither 4-D (2-D images) nor 5-D (3-D images).

    See Also
    --------
    train : The training function for single-task (intensity-only) models.
    """
    colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin','eval_msq','eval_dice']
    training_path = np.zeros( [ max_iterations, len(colnames) ] )
    training_weights = np.zeros( [1,4] )
    if verbose:
        print("begin get feature extractor")
    if feature_type == 'grader':
        feature_extractor = get_grader_feature_network( feature_layer )
    elif feature_type == 'vggrandom':
        feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False )
    else:
        # Any other value (including 'vgg') uses the unbiased VGG extractor.
        feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer  )
    if verbose:
        print("begin train generator")
    mydatgen = seg_generator(
        filenames_train,
        nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=False , verbose=False)

    def _fresh_test_batch():
        # Builds a brand-new test generator for every draw and takes one item.
        # When istest=True, seg_generator yields a 2-tuple right after each
        # 3-tuple, so a second next() on the same generator would not unpack
        # into three values.
        gen = seg_generator( filenames_test, nPatches=1,
            target_patch_size=target_patch_size,
            target_patch_size_low=target_patch_size_low,
            istest=True, verbose=True)
        return next( gen )

    if verbose:
        print("begin test generator")
    # Single-sample batch used as Keras validation_data and for weight balancing.
    patchesResamTeTf, patchesOrigTeTf, _ = _fresh_test_batch()
    ndim = len( patchesOrigTeTf.shape )
    if ndim not in ( 4, 5 ):
        raise ValueError(
            "expected 4-D (2-D image) or 5-D (3-D image) patches, got rank " + str( ndim ) )
    tdim = ndim - 2  # spatial dimensionality; the channel axis is tdim + 1
    if verbose:
        print("begin train generator #2 at dim: " + str( tdim))
    # Larger evaluation set of n_test independent draws (at least one), stacked on axis 0.
    test_draws = [ _fresh_test_batch() for _ in range( max( n_test, 1 ) ) ]
    patchesResamTeTfB = tf.concat( [ d[0] for d in test_draws ], axis=0 )
    patchesOrigTeTfB = tf.concat( [ d[1] for d in test_draws ], axis=0 )
    patchesUpTeTfB = tf.concat( [ d[2] for d in test_draws ], axis=0 )
    if verbose:
        print("begin auto_weight_loss_seg")
    wts_csv = output_prefix + "_training_weights.csv"
    if exists( wts_csv ):
        # Reuse weights computed by an earlier run with the same prefix.
        wtsdf = pd.read_csv( wts_csv )
        wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0], wtsdf['dice'][0]]
        if verbose:
            print( "preset weights:" )
    else:
        wts = auto_weight_loss_seg( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf,
            feature=feature, tv=tv, dice=dice )
        for k in range(len(wts)):
            training_weights[0,k]=wts[k]
        pd.DataFrame(training_weights, columns = ["msq","feat","tv","dice"] ).to_csv( wts_csv )
        if verbose:
            print( "automatic weights:" )
    if verbose:
        print( wts )
    def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], dicewt=wts[3], mybs = batch_size ):
        """Composite loss: MSE + Perceptual + TV (intensity channel) + Dice (segmentation channel)."""
        rank = len( y_true.shape )
        if rank not in ( 4, 5 ):
            raise ValueError(
                "expected 4-D (2-D image) or 5-D (3-D image) tensors, got rank " + str( rank ) )
        tdim = rank - 2
        myax = list( range( 1, rank ) )  # every axis except the batch axis
        # Channel 0 is intensity, channel 1 is segmentation.
        y_intensity, y_seg = tf.split( y_true, 2, axis=tdim+1 )
        y_intensity_p, y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )
        squared_difference = tf.square(y_intensity - y_intensity_p)
        msqTerm = tf.reduce_mean(squared_difference, axis=myax)
        temp1 = feature_extractor(y_intensity)
        temp2 = feature_extractor(y_intensity_p)
        feature_difference = tf.square(temp1-temp2)
        featureTerm = tf.reduce_mean(feature_difference, axis=myax)
        loss = msqTerm * msqwt + featureTerm * fw
        mytv = tf.cast( 0.0, 'float32')
        if tdim == 3:
            # NOTE(review): indexing fails if the actual batch is smaller than
            # batch_size, and larger batches only contribute their first
            # batch_size samples to the TV term.
            for k in range( mybs ):
                # [x, y, z] volume; total_variation reads it as an [h, w, c] image.
                sqzd = y_intensity_p[k,:,:,:,0]
                mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt
        if tdim == 2:
            # Keep the channel axis: a [batch, h, w, 1] tensor is scored per image
            # over both spatial axes. Slicing it away would make total_variation
            # read [batch, h, w] as a single [h, w, c] image.
            mytv = tf.reduce_mean( tf.image.total_variation( y_intensity_p ) ) * tvwt
        dicer = tf.reduce_mean( dicewt * binary_dice_loss( y_seg, y_seg_p ) )
        return( loss + mytv + dicer )
    if verbose:
        print("begin model compilation")
    opt = tf.keras.optimizers.Adam( learning_rate=learning_rate )
    mdl.compile(optimizer=opt, loss=my_loss_6)
    bestValLoss=1e12
    if verbose:
        print( "begin training", flush=True  )
    for myrs in range( max_iterations ):
        tracker = mdl.fit( mydatgen,  epochs=2, steps_per_epoch=4, verbose=1,
            validation_data=(patchesResamTeTf,patchesOrigTeTf) )
        training_path[myrs,0]=tracker.history['loss'][0]
        training_path[myrs,1]=tracker.history['val_loss'][0]
        training_path[myrs,2]=0
        print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True  )
        if myrs % check_eval_data_iteration == 0:
            with tf.device("/cpu:0"):
                myofn = output_prefix + "_best_mdl.keras"
                if save_all_best:
                    myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras"
                tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB )
                if ( tester < bestValLoss ):
                    print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True )
                    bestValLoss = tester
                    tf.keras.models.save_model( mdl, myofn )
                    training_path[myrs,2]=1
                pp = mdl.predict( patchesResamTeTfB, batch_size = 1 )
                # Split channel 0 (intensity) from channel 1 (segmentation).
                pp = tf.split( pp, 2, axis=tdim+1 )
                y_orig = tf.split( patchesOrigTeTfB, 2, axis=tdim+1 )
                y_up = tf.split( patchesUpTeTfB, 2, axis=tdim+1 )
                myssimSR = tf.image.psnr( pp[0] * 220, y_orig[0]* 220, max_val=255 )
                myssimSR = tf.reduce_mean( myssimSR ).numpy()
                myssimBI = tf.image.psnr( y_up[0] * 220, y_orig[0]* 220, max_val=255 )
                myssimBI = tf.reduce_mean( myssimBI ).numpy()
                squared_difference = tf.square(y_orig[0] - pp[0])
                msqTerm = tf.reduce_mean(squared_difference).numpy()
                dicer = binary_dice_loss( y_orig[1], pp[1] )
                dicer = tf.reduce_mean( dicer ).numpy()
                print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ) + " MSQ: " + str(msqTerm) + " DICE: " + str(dicer), flush=True  )
                training_path[myrs,3]=myssimSR # psnr
                training_path[myrs,4]=myssimBI # psnrlin
                training_path[myrs,5]=msqTerm # msq
                training_path[myrs,6]=dicer # dice
                pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" )
    training_path = pd.DataFrame(training_path, columns = colnames )
    return training_path
1942
1943
def read_srmodel(srfilename, custom_objects=None):
    """
    Load a super-resolution model (h5, .keras, or SavedModel format),
    and determine its upsampling factor.

    The factor is measured by running a zero-filled probe through the model and
    comparing output to input spatial size. Spatial axes the model declares as
    fixed are probed at that size; unspecified (None) axes are probed at 8.

    Parameters
    ----------
    srfilename : str
        Path to the model file (.h5, .keras, or a SavedModel folder).
        A leading ``~`` is expanded.
    custom_objects : dict, optional
        Dictionary of custom objects used in the model (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)})

    Returns
    -------
    model : tf.keras.Model
        The loaded (uncompiled) model.
    upsampling_factor : list of int
        List describing the upsampling factor:
        - For 3D input: [x_up, y_up, z_up, channels]
        - For 2D input: [x_up, y_up, channels]
        where ``channels`` is the model's number of input channels.

    Raises
    ------
    ValueError
        If the path is neither a directory nor a .h5/.keras file.
    RuntimeError
        If the probe forward pass fails; the original error is chained.

    Example
    -------
    >>> mdl, up = read_srmodel("mymodel.keras")
    >>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)})
    """
    srfilename = os.path.expanduser(srfilename)
    ext = os.path.splitext(srfilename)[1].lower()

    # SavedModel directories and single-file .h5/.keras models use the same loader.
    if not (os.path.isdir(srfilename) or ext in ('.h5', '.keras')):
        raise ValueError(f"Unsupported model format: {ext}")
    model = tf.keras.models.load_model(srfilename, custom_objects=custom_objects, compile=False)

    # Channels-last input: rank 4 is 2-D ([b, x, y, c]), rank 5 is 3-D.
    input_shape = model.input_shape
    if isinstance(input_shape, list):
        input_shape = input_shape[0]
    chanindex = 3 if len(input_shape) == 4 else 4
    nchan = int(input_shape[chanindex])
    # Probe size per spatial axis: the declared size when fixed, otherwise 8.
    spatial = [8 if d is None else int(d) for d in input_shape[1:chanindex]]

    try:
        dummy_input = np.zeros([1] + spatial + [nchan], dtype='float32')
        # Some models only accept a dict keyed by input name.
        try:
            output = model(dummy_input)
        except Exception:
            output = model({model.input_names[0]: dummy_input})
        outshp = output.shape
        factors = [int(outshp[i + 1] / spatial[i]) for i in range(len(spatial))]
    except Exception as e:
        raise RuntimeError(f"Could not infer upsampling factor. Error: {e}") from e

    return model, factors + [nchan]
2011
2012
def simulate_image( shaper=(32,32,32), n_levels=10, multiply=False ):
    """
    Generate a random test image by summing smoothed, thresholded noise masks.

    Arguments
    ---------
    shaper : sequence of int
        Image shape, ``[x, y, z]`` or ``[x, y]``. Default ``(32, 32, 32)``.

    n_levels : int
        Number of random masks to accumulate. The same value is also passed
        as the smoothing amount to ``ants.smooth_image``.

    multiply : boolean
        If True, mask ``k`` (0-based) is scaled by ``k`` before being added,
        so the first mask contributes nothing.

    Returns
    -------
    ants.image
        Sum of the ``n_levels`` (optionally scaled) Otsu masks.
    """
    # Start from a zero image. A random array is drawn and discarded so the
    # numpy RNG stream (and hence seeded outputs) matches earlier versions.
    img = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) * 0
    for k in range(n_levels):
        noise = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) )
        smoothed = ants.smooth_image( noise, n_levels )
        mask = ants.threshold_image( smoothed, "Otsu", 1 )
        if multiply:
            mask = mask * k
        img = img + mask
    return img
2040
2041
def optimize_upsampling_shape( spacing, modality='T1', roundit=False, verbose=False ):
    """
    Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing
    and imaging modality. This output is used to select an appropriate pretrained
    super-resolution model filename.

    Parameters
    ----------
    spacing : sequence of float
        Voxel spacing (physical size per voxel in mm) from the input image.
        Typically obtained from `ants.get_spacing(image)`.

    modality : str, optional
        Imaging modality. Affects resolution thresholds:
        - 'T1' : anatomical MRI (default minimum spacing: 0.35 mm)
        - 'DTI' : diffusion MRI (default minimum spacing: 1.0 mm)
        - 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm)
        Any other value is treated as 'T1'.

    roundit : bool, optional
        If True, uses rounded integer ratios for the upsampling shape.
        Otherwise, uses floor (truncation). The same constraints apply to both.

    verbose : bool, optional
        If True, prints detailed internal values and logic.

    Returns
    -------
    str
        Optimal upsampling shape string with one factor per spacing entry,
        e.g. '2x2x2' or '4x4x2' for 3-D spacing, '2x2' for 2-D spacing.

    Notes
    -----
    - Each factor is constrained to the supported set: 0 becomes 1, 5 becomes 4,
      and anything above 6 becomes 6.
    - If every factor is 1, every factor is replaced by 2 (e.g. '1x1x1' -> '2x2x2').
    - The returned string is commonly used to populate a model filename template:

      Example:
          >>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1')
          >>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras')
    """
    # Finest spacing (mm) the upsampled image may reach for each modality.
    if modality == "NM":
        min_allowed = 0.25
    elif modality == "DTI":
        min_allowed = 1.0
    else: # assume T1
        min_allowed = 0.35
    if verbose:
        print("Using minspacing: " + str(min_allowed))

    def _constrain( factor ):
        # Map a raw integer factor onto the supported upsampling factors.
        if factor == 0:
            return 1
        if factor == 5:
            return 4
        return min( factor, 6 )

    spacing = list( spacing )
    minspc = min( spacing )
    maxspc = max( spacing )
    ratio = maxspc / minspc
    if ratio == 1.0:
        # Isotropic input: np.round(0.5) == 0.0 (round-half-to-even), so every
        # axis takes the min_allowed branch below.
        ratio = 0.5
    roundratio = np.round( ratio )
    tarshaperaw = []
    tarshape = []
    tarshaperound = []
    for spc in spacing:
        locrat = spc / minspc
        tarshaperaw.append( locrat )
        # If upsampling by the rounded anisotropy ratio would go below the
        # modality limit, size the factor against that limit instead.
        if spc * roundratio < min_allowed:
            locrat = spc / min_allowed
        tarshape.append( str( _constrain( int( locrat ) ) ) )
        tarshaperound.append( str( _constrain( int( np.round( locrat ) ) ) ) )
    if verbose:
        print("before emendation:")
        print( tarshaperaw )
        print( tarshaperound )
        print( tarshape )
    chosen = tarshaperound if roundit else tarshape
    if all( factor == "1" for factor in chosen ):
        chosen = ["2"] * len( chosen ) # default
    return "x".join( chosen )
2135
def compare_models( model_filenames, img, n_classes=3,
    poly_order='hist',
    identifier=None, noise_sd=0.1,verbose=False ):
    """
    Evaluates and compares the performance of multiple super-resolution models on a given image.

    This function provides a standardized way to benchmark SR models. For each model,
    it performs the following steps:
    1. Loads the model and determines its upsampling factor.
    2. Downsamples the high-resolution input image (`img`) with linear interpolation
       to create a low-resolution input, simulating a real-world scenario.
    3. Adds Gaussian noise to the low-resolution input to test for robustness.
    4. Runs inference using the model to generate a super-resolved output.
    5. Generates a baseline output by upsampling the low-res input with linear interpolation.
    6. Calculates PSNR and SSIM metrics comparing both the model's output and the
       baseline against the original image resampled onto the SR output grid.
    7. If a dual-channel (image + segmentation) model is detected, it also calculates
       Dice scores for segmentation performance.
    8. Aggregates all results into a pandas DataFrame for easy comparison.

    Parameters
    ----------
    model_filenames : iterable of str
        File paths to the Keras models (.h5, .keras) to be compared.
    img : ants.ANTsImage
        The high-resolution ground truth image. This image will be downsampled to
        create the input for the models.
    n_classes : int, optional
        The number of classes for Otsu's thresholding when auto-generating a
        segmentation for evaluating dual-channel models. Default is 3.
    poly_order : str or int, optional
        Method for intensity matching between the SR output and the reference.
        Options: 'hist' for histogram matching (default), an integer for
        polynomial regression, or None to disable.
    identifier : str, optional
        A custom identifier written to every row of the output DataFrame. If None,
        each row's identifier is inferred from that row's own model filename.
        Default is None.
    noise_sd : float, optional
        Standard deviation of the additive Gaussian noise applied to the
        downsampled image before inference. Default is 0.1.
    verbose : bool, optional
        If True, prints detailed progress and intermediate values. Default is False.

    Returns
    -------
    pd.DataFrame
        One row per model; every row keeps index label 0 (rows are single-row
        frames concatenated together). Columns: identifier, imgshape, mdl,
        mdlshape, PSNR.LIN, PSNR.SR, SSIM.LIN, SSIM.SR, DICE.NN, DICE.SR and
        dimwarning. The DICE columns are NaN for single-channel models. An empty
        DataFrame is returned when `model_filenames` is empty.

    Notes
    -----
    When evaluating a 2-channel (segmentation) model, the primary metric for the
    segmentation task is the Dice score (`DICE.SR`). The intensity metrics (PSNR, SSIM)
    are still computed on the first channel.
    """
    frames = []
    for model_filename in model_filenames:
        srmdl, upshape = read_srmodel( model_filename )
        if verbose:
            print( model_filename )
            print( upshape )
        # Spacing of the simulated low-resolution input: each spatial axis is
        # coarsened by the model's upsampling factor on that axis.
        inspc = ants.get_spacing( img )
        tarshape = [ float( upshape[j] ) * inspc[j] for j in range( len( img.shape ) ) ]
        # interp_type=0 -> linear interpolation
        dimg = ants.resample_image( img, tarshape, use_voxels=False, interp_type=0 )
        dimg = ants.add_noise_to_image( dimg, 'additivegaussian', [0, noise_sd] )
        dicesr = math.nan
        dicenn = math.nan
        # NOTE(review): upshape[3] is assumed to be the model's channel count
        # (i.e. a 3-D model); 2 channels means an image + segmentation model.
        if upshape[3] == 2:
            seghigh = ants.threshold_image( img, "Otsu", n_classes )
            seglow = ants.resample_image( seghigh, tarshape, use_voxels=False, interp_type=1 )
            srout = inference( dimg, srmdl, segmentation=seglow, poly_order=poly_order, verbose=verbose )
            dimgupseg = srout['super_resolution_segmentation']
            dimgup = srout['super_resolution']
            segblock = ants.resample_image_to_target( seghigh, dimgupseg, interp_type='nearestNeighbor' )
            segimgnn = ants.resample_image_to_target( seglow, dimgupseg, interp_type='nearestNeighbor' )
            # Only score voxels that the SR segmentation labels.
            segblock[ dimgupseg == 0 ] = 0
            segimgnn[ dimgupseg == 0 ] = 0
            dicenn = ants.label_overlap_measures( segblock, segimgnn )['MeanOverlap'][0]
            dicesr = ants.label_overlap_measures( segblock, dimgupseg )['MeanOverlap'][0]
        else:
            dimgup = inference( dimg, srmdl, poly_order=poly_order, verbose=verbose )
        # Baseline (linear upsampling) and reference, both on the SR output grid.
        dimglin = ants.resample_image_to_target( dimg, dimgup, interp_type='linear' )
        imgblock = ants.resample_image_to_target( img, dimgup, interp_type='linear' )
        dimgup[ imgblock == 0.0 ] = 0.0
        dimglin[ imgblock == 0.0 ] = 0.0
        dimwarning = any( img.shape[jj] != imgblock.shape[jj] for jj in range( img.dimension ) )
        if dimwarning:
            print("NOTE: dimensions of downsampled to upsampled image do not match!!!")
            print("we force them to match but this suggests results may not be reliable.")
        # Short model tag: the basename with the standard siq prefix/suffixes
        # removed (literal substrings, applied in this order).
        model_tag = os.path.basename( model_filename )
        for token in ( "siq_default_sisr_", "_best_mdl.keras", "_best_mdl.h5" ):
            model_tag = model_tag.replace( token, "" )
        if verbose and dimwarning:
            print( "original img shape" )
            print( img.shape )
            print( "resampled img shape" )
            print( imgblock.shape )
        mydict = {
            # Per-row fallback: never carry one model's tag over to the next.
            "identifier": identifier if identifier is not None else model_tag,
            "imgshape": "x".join( str( s ) for s in imgblock.shape[:len( upshape )] ),
            "mdl": model_tag,
            "mdlshape": "x".join( str( u ) for u in upshape ),
            "PSNR.LIN": antspynet.psnr( imgblock, dimglin ),
            "PSNR.SR": antspynet.psnr( imgblock, dimgup ),
            "SSIM.LIN": antspynet.ssim( imgblock, dimglin ),
            "SSIM.SR": antspynet.ssim( imgblock, dimgup ),
            "DICE.NN": dicenn,
            "DICE.SR": dicesr,
            "dimwarning": dimwarning }
        if verbose:
            print( mydict )
        frames.append( pd.DataFrame.from_records( [mydict], index=[0] ) )
    if not frames:
        return pd.DataFrame()
    # Single concat at the end: linear cost, and no concatenation onto an
    # empty frame (deprecated dtype behaviour in recent pandas).
    return pd.concat( frames, axis=0 )
2271
2272
2273
2274
2275def region_wise_super_resolution(image, mask, super_res_model, dilation_amount=4, verbose=False):
2276    """
2277    Apply super-resolution model to each labeled region in the mask independently.
2278
2279    Arguments
2280    ---------
2281    image : ANTsImage
2282        Input image.
2283
2284    mask : ANTsImage
2285        Integer-labeled segmentation mask with non-zero regions to upsample.
2286
2287    super_res_model : tf.keras.Model
2288        Trained super-resolution model.
2289
2290    dilation_amount : int
2291        Number of morphological dilations applied to each label region before cropping.
2292
2293    verbose : bool
2294        If True, print detailed status.
2295
2296    Returns
2297    -------
2298    ANTsImage : Full-size super-resolved image with per-label inference and stitching.
2299    """
2300    import ants
2301    import numpy as np
2302    from antspynet import apply_super_resolution_model_to_image
2303
2304    upFactor = []
2305    input_shape = super_res_model.inputs[0].shape
2306    test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1]
2307    test_input = np.zeros(test_shape, dtype=np.float32)
2308    test_output = super_res_model(test_input)
2309
2310    for k in range(len(test_shape) - 2):  # ignore batch + channel
2311        upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1]))
2312
2313    original_size = mask.shape  # e.g., (x, y, z)
2314    new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor))
2315    upsampled_mask = ants.resample_image(mask, new_size, use_voxels=True, interp_type=1)
2316    upsampled_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0)
2317
2318    unique_labels = list(np.unique(upsampled_mask.numpy()))
2319    if 0 in unique_labels:
2320        unique_labels.remove(0)
2321
2322    outimg = ants.image_clone(upsampled_image)
2323
2324    for lab in unique_labels:
2325        if verbose:
2326            print(f"Processing label: {lab}")
2327        regionmask = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount)
2328        cropped = ants.crop_image(image, regionmask)
2329        if cropped.shape[0] == 0:
2330            continue
2331        subimgsr = apply_super_resolution_model_to_image(
2332            cropped, super_res_model, target_range=[0, 1], verbose=verbose
2333        )
2334        stitched = ants.decrop_image(subimgsr, outimg)
2335        outimg[upsampled_mask == lab] = stitched[upsampled_mask == lab]
2336
2337    return outimg
2338
2339
2340def region_wise_super_resolution_blended(image, mask, super_res_model, dilation_amount=4, verbose=False):
2341    """
2342    Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts.
2343
2344    This version uses a weighted-averaging scheme based on distance transforms
2345    to create seamless transitions between super-resolved regions and the background.
2346
2347    Arguments
2348    ---------
2349    image : ANTsImage
2350        Input low-resolution image.
2351
2352    mask : ANTsImage
2353        Integer-labeled segmentation mask.
2354
2355    super_res_model : tf.keras.Model
2356        Trained super-resolution model.
2357
2358    dilation_amount : int
2359        Number of morphological dilations applied to each label region before cropping.
2360        This provides context to the SR model.
2361
2362    verbose : bool
2363        If True, print detailed status.
2364
2365    Returns
2366    -------
2367    ANTsImage : Full-size, super-resolved image with seamless blending.
2368    """
2369    import ants
2370    import numpy as np
2371    from antspynet import apply_super_resolution_model_to_image
2372    epsilon32 = np.finfo(np.float32).eps
2373    normalize_weight_maps = True  # Default behavior to normalize weight maps
2374    # --- Step 1: Determine upsampling factor and prepare initial images ---
2375    upFactor = []
2376    input_shape = super_res_model.inputs[0].shape
2377    test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1]
2378    test_input = np.zeros(test_shape, dtype=np.float32)
2379    test_output = super_res_model(test_input)
2380    for k in range(len(test_shape) - 2):
2381        upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1]))
2382
2383    original_size = image.shape
2384    new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor))
2385
2386    # The initial upsampled image will serve as our background
2387    background_sr_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0)
2388
2389    # --- Step 2: Initialize accumulator and weight sum canvases ---
2390    # These must be float type for accumulation
2391    accumulator = ants.image_clone(background_sr_image).astype('float32') * 0.0
2392    weight_sum = ants.image_clone(accumulator)
2393
2394    unique_labels = [l for l in np.unique(mask.numpy()) if l != 0]
2395
2396    for lab in unique_labels:
2397        if verbose:
2398            print(f"Blending label: {lab}")
2399
2400        # --- Step 3: Super-resolve a dilated patch (provides context to the model) ---
2401        region_mask_dilated = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount)
2402        cropped_lowres = ants.crop_image(image, region_mask_dilated)
2403        if cropped_lowres.shape[0] == 0:
2404            continue
2405            
2406        # Apply the model to the cropped low-res patch
2407        sr_patch = apply_super_resolution_model_to_image(
2408            cropped_lowres, super_res_model, target_range=[0, 1]
2409        )
2410        
2411        # Place the super-resolved patch back onto a full-sized canvas
2412        sr_patch_full_size = ants.decrop_image(sr_patch, accumulator)
2413
2414        # --- Step 4: Create a smooth weight map for this region ---
2415        # We use the *non-dilated* mask for the weight map to ensure a sharp focus on the target region.
2416        region_mask_original = ants.threshold_image(mask, lab, lab)
2417        
2418        # Resample the original region mask to the high-res grid
2419        weight_map = ants.resample_image(region_mask_original, new_size, use_voxels=True, interp_type=0)
2420        weight_map = ants.smooth_image(weight_map, sigma=2.0,
2421                                        sigma_in_physical_coordinates=False)
2422        if normalize_weight_maps:
2423            weight_map = ants.iMath(weight_map, "Normalize")
2424        # --- Step 5: Accumulate the weighted values and the weights themselves ---
2425        accumulator += sr_patch_full_size * weight_map
2426        weight_sum += weight_map
2427
2428    # --- Step 6: Final Combination ---
2429    # Normalize the accumulator by the total weight at each pixel
2430    weight_sum_np = weight_sum.numpy()
2431    accumulator_np = accumulator.numpy()
2432    
2433    # Create a mask of pixels where blending occurred
2434    blended_mask = weight_sum_np > 0.0 # Use a small epsilon for float safety
2435
2436    # Start with the original upsampled image as the base
2437    final_image_np = background_sr_image.numpy()
2438    
2439    # Perform the weighted average only where weights are non-zero
2440    final_image_np[blended_mask] = accumulator_np[blended_mask] / weight_sum_np[blended_mask]
2441    
2442    # Re-insert any non-blended background regions that were processed
2443    # This handles cases where regions overlap; the weighted average takes care of it.
2444    
2445    return ants.from_numpy(final_image_np, origin=background_sr_image.origin, 
2446                           spacing=background_sr_image.spacing, direction=background_sr_image.direction)
2447     
2448
2449def inference(
2450    image,
2451    mdl,
2452    truncation=None,
2453    segmentation=None,
2454    target_range=[1, 0],
2455    poly_order='hist',
2456    dilation_amount=0,
2457    verbose=False):
2458    """
2459    Perform super-resolution inference on an input image, optionally guided by segmentation.
2460
2461    This function uses a trained deep learning model to enhance the resolution of a medical image.
2462    It optionally applies label-wise inference if a segmentation mask is provided.
2463
2464    Parameters
2465    ----------
2466    image : ants.ANTsImage
2467        Input image to be super-resolved.
2468
2469    mdl : keras.Model
2470        Trained super-resolution model, typically from ANTsPyNet.
2471
2472    truncation : tuple or list of float, optional
2473        Percentile values (e.g., [0.01, 0.99]) for intensity truncation before model input.
2474        If None, no truncation is applied.
2475
2476    segmentation : ants.ANTsImage, optional
2477        A labeled segmentation mask. If provided, super-resolution is performed per label
2478        using `region_wise_super_resolution` or `super_resolution_segmentation_per_label`.
2479
2480    target_range : list of float
2481        Intensity range used for scaling the input before applying the model.
2482        Default is [1, 0] (internal default for `apply_super_resolution_model_to_image`).
2483
2484    poly_order : int, str or None
2485        Determines how to match intensity between the super-resolved image and the original.
2486        Options:
2487          - 'hist' : use histogram matching
2488          - int >= 1 : perform polynomial regression of this order
2489          - None : no intensity adjustment
2490
2491    dilation_amount : int
2492        Number of dilation steps applied to each segmentation label during
2493        region-based super-resolution (if segmentation is provided).
2494
2495    verbose : bool
2496        If True, print progress and status messages.
2497
2498    Returns
2499    -------
2500    ANTsImage or dict
2501        - If `segmentation` is None, returns a single ANTsImage (super-resolved image).
2502        - If `segmentation` is provided, returns a dictionary with:
2503            - 'super_resolution': ANTsImage
2504            - other entries may include label-wise results or metadata.
2505
2506    Examples
2507    --------
2508    >>> import ants
2509    >>> import antspynet
2510    >>> from siq import inference
2511    >>> img = ants.image_read("lowres.nii.gz")
2512    >>> model = antspynet.get_pretrained_network("dbpn", target_suffix="T1")
2513    >>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True)
2514
2515    >>> seg = ants.image_read("mask.nii.gz")
2516    >>> sr_result = inference(img, model, segmentation=seg)
2517    >>> srimg = sr_result['super_resolution']
2518    """
2519    import ants
2520    import numpy as np
2521    import antspynet
2522    import antspyt1w
2523    from siq import region_wise_super_resolution
2524
2525    def apply_intensity_match(sr_image, reference_image, order, verbose=False):
2526        if order is None:
2527            return sr_image
2528        if verbose:
2529            print("Applying intensity match with", order)
2530        if order == 'hist':
2531            return ants.histogram_match_image(sr_image, reference_image)
2532        else:
2533            return ants.regression_match_image(sr_image, reference_image, poly_order=order)
2534
2535    pimg = ants.image_clone(image)
2536    if truncation is not None:
2537        pimg = ants.iMath(pimg, 'TruncateIntensity', truncation[0], truncation[1])
2538
2539    input_shape = mdl.inputs[0].shape
2540    num_channels = int(input_shape[-1])
2541
2542    if segmentation is not None:
2543        if num_channels == 1:
2544            if verbose:
2545                print("Using region-wise super resolution due to single-channel model with segmentation.")
2546            sr = region_wise_super_resolution_blended(
2547                pimg, segmentation, mdl,
2548                dilation_amount=dilation_amount,
2549                verbose=verbose
2550            )
2551            ref = ants.resample_image_to_target(pimg, sr)
2552            return apply_intensity_match(sr, ref, poly_order, verbose)
2553        else:
2554            mynp = segmentation.numpy()
2555            mynp = list(np.unique(mynp)[1:len(mynp)].astype(int))
2556            upFactor = []
2557            if len(input_shape) == 5:
2558                testarr = np.zeros([1, 8, 8, 8, 2])
2559                testarrout = mdl(testarr)
2560                for k in range(3):
2561                    upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1]))
2562            elif len(input_shape) == 4:
2563                testarr = np.zeros([1, 8, 8, 2])
2564                testarrout = mdl(testarr)
2565                for k in range(2):
2566                    upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1]))
2567            temp = antspyt1w.super_resolution_segmentation_per_label(
2568                pimg,
2569                segmentation,
2570                upFactor,
2571                mdl,
2572                segmentation_numbers=mynp,
2573                target_range=target_range,
2574                dilation_amount=dilation_amount,
2575                poly_order=poly_order,
2576                max_lab_plus_one=True
2577            )
2578            imgsr = temp['super_resolution']
2579            ref = ants.resample_image_to_target(pimg, imgsr)
2580            return apply_intensity_match(imgsr, ref, poly_order, verbose)
2581
2582    # Default path: no segmentation
2583    imgsr = antspynet.apply_super_resolution_model_to_image(
2584        pimg, mdl, target_range=target_range, regression_order=None, verbose=verbose
2585    )
2586    ref = ants.resample_image_to_target(pimg, imgsr)
2587    return apply_intensity_match(imgsr, ref, poly_order, verbose)
DATA_PATH = '/Users/stnava/.siq/'
def get_data(name=None, force_download=False, version=0, target_extension='.csv'):
31def get_data( name=None, force_download=False, version=0, target_extension='.csv' ):
32    """
33    Get SIQ data filename
34
35    The first time this is called, it will download data to ~/.siq.
36    After, it will just read data from disk.  The ~/.siq may need to
37    be periodically deleted in order to ensure data is current.
38
39    Arguments
40    ---------
41    name : string
42        name of data tag to retrieve
43        Options:
44            - 'all'
45
46    force_download: boolean
47
48    version: version of data to download (integer)
49
50    Returns
51    -------
52    string
53        filepath of selected data
54
55    Example
56    -------
57    >>> import siq
58    >>> siq.get_data()
59    """
60    os.makedirs(DATA_PATH, exist_ok=True)
61
62    def download_data( version ):
63        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
64        target_file_name = "16912366.zip"
65        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
66            cache_subdir=DATA_PATH, extract = True )
67        os.remove( DATA_PATH + target_file_name )
68
69    if force_download:
70        download_data( version = version )
71
72
73    files = []
74    for fname in os.listdir(DATA_PATH):
75        if ( fname.endswith(target_extension) ) :
76            fname = os.path.join(DATA_PATH, fname)
77            files.append(fname)
78
79    if len( files ) == 0 :
80        download_data( version = version )
81        for fname in os.listdir(DATA_PATH):
82            if ( fname.endswith(target_extension) ) :
83                fname = os.path.join(DATA_PATH, fname)
84                files.append(fname)
85
86    if name == 'all':
87        return files
88
89    datapath = None
90
91    for fname in os.listdir(DATA_PATH):
92        mystem = (Path(fname).resolve().stem)
93        mystem = (Path(mystem).resolve().stem)
94        mystem = (Path(mystem).resolve().stem)
95        if ( name == mystem and fname.endswith(target_extension) ) :
96            datapath = os.path.join(DATA_PATH, fname)
97
98    return datapath

Get SIQ data filename

The first time this is called, it will download data to ~/.siq. After, it will just read data from disk. The ~/.siq may need to be periodically deleted in order to ensure data is current.

Arguments

name : string — the name (file stem) of the data tag to retrieve. Options include 'all', which returns every downloaded file with the target extension.

force_download : boolean — if True, download the data again even if it is already on disk.

version: version of data to download (integer)

Returns

string — filepath of the selected data; a list of filepaths when name is 'all', or None if no file matches.

Example

>>> import siq
>>> siq.get_data()
def dbpn( input_image_size, number_of_outputs=1, number_of_base_filters=64, number_of_feature_filters=256, number_of_back_projection_stages=7, convolution_kernel_size=(12, 12), strides=(8, 8), last_convolution=(3, 3), number_of_loss_functions=1, interpolation='nearest'):
119def dbpn(input_image_size,
120                                                 number_of_outputs=1,
121                                                 number_of_base_filters=64,
122                                                 number_of_feature_filters=256,
123                                                 number_of_back_projection_stages=7,
124                                                 convolution_kernel_size=(12, 12),
125                                                 strides=(8, 8),
126                                                 last_convolution=(3, 3),
127                                                 number_of_loss_functions=1,
128                                                 interpolation = 'nearest'
129                                                ):
130    """
131    Creates a Deep Back-Projection Network (DBPN) for single image super-resolution.
132
133    This function constructs a Keras model based on the DBPN architecture, which
134    can be configured for either 2D or 3D inputs. The network uses iterative
135    up- and down-projection blocks to refine the high-resolution image estimate. A
136    key modification from the original paper is the option to use standard
137    interpolation for upsampling instead of deconvolution layers.
138
139    Reference:
140     - Haris, M., Shakhnarovich, G., & Ukita, N. (2018). Deep Back-Projection
141       Networks For Super-Resolution. In CVPR.
142
143    Parameters
144    ----------
145    input_image_size : tuple or list
146        The shape of the input image, including the channel.
147        e.g., `(None, None, 1)` for 2D or `(None, None, None, 1)` for 3D.
148
149    number_of_outputs : int, optional
150        The number of channels in the output image. Default is 1.
151
152    number_of_base_filters : int, optional
153        The number of filters in the up/down projection blocks. Default is 64.
154
155    number_of_feature_filters : int, optional
156        The number of filters in the initial feature extraction layer. Default is 256.
157
158    number_of_back_projection_stages : int, optional
159        The number of iterative back-projection stages (T in the paper). Default is 7.
160
161    convolution_kernel_size : tuple or list, optional
162        The kernel size for the main projection convolutions. Should match the
163        dimensionality of the input. Default is (12, 12).
164
165    strides : tuple or list, optional
166        The strides for the up/down sampling operations, defining the
167        super-resolution factor. Default is (8, 8).
168
169    last_convolution : tuple or list, optional
170        The kernel size of the final reconstruction convolution. Default is (3, 3).
171
172    number_of_loss_functions : int, optional
173        If greater than 1, the model will have multiple identical output branches.
174        Typically set to 1. Default is 1.
175
176    interpolation : str, optional
177        The interpolation method to use for upsampling layers if not using
178        transposed convolutions. 'nearest' or 'bilinear'. Default is 'nearest'.
179
180    Returns
181    -------
182    keras.Model
183        A Keras model implementing the DBPN architecture for the specified
184        parameters.
185    """
186    idim = len( input_image_size ) - 1
187    if idim == 2:
188        myconv = Conv2D
189        myconv_transpose = Conv2DTranspose
190        myupsampling = UpSampling2D
191        shax = ( 1, 2 )
192        firstConv = (3,3)
193        firstStrides=(1,1)
194        smashConv=(1,1)
195    if idim == 3:
196        myconv = Conv3D
197        myconv_transpose = Conv3DTranspose
198        myupsampling = UpSampling3D
199        shax = ( 1, 2, 3 )
200        firstConv = (3,3,3)
201        firstStrides=(1,1,1)
202        smashConv=(1,1,1)
203    def up_block_2d(L, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
204                    include_dense_convolution_layer=True):
205        if include_dense_convolution_layer == True:
206            L = myconv(filters = number_of_filters,
207                       use_bias=True,
208                       kernel_size=smashConv,
209                       strides=firstStrides,
210                       padding='same')(L)
211            L = PReLU(alpha_initializer='zero',
212                      shared_axes=shax)(L)
213        # Scale up
214        if idim == 2:
215            H0 = myupsampling( size = strides, interpolation=interpolation )(L)
216        if idim == 3:
217            H0 = myupsampling( size = strides )(L)
218        H0 = myconv(filters=number_of_filters,
219                    kernel_size=firstConv,
220                    strides=firstStrides,
221                    use_bias=True,
222                    padding='same')(H0)
223        H0 = PReLU(alpha_initializer='zero',
224                   shared_axes=shax)(H0)
225        # Scale down
226        L0 = myconv(filters=number_of_filters,
227                    kernel_size=kernel_size,
228                    strides=strides,
229                    kernel_initializer='glorot_uniform',
230                    padding='same')(H0)
231        L0 = PReLU(alpha_initializer='zero',
232                   shared_axes=shax)(L0)
233        # Residual
234        E = Subtract()([L0, L])
235        # Scale residual up
236        if idim == 2:
237            H1 = myupsampling( size = strides, interpolation=interpolation  )(E)
238        if idim == 3:
239            H1 = myupsampling( size = strides )(E)
240        H1 = myconv(filters=number_of_filters,
241                    kernel_size=firstConv,
242                    strides=firstStrides,
243                    use_bias=True,
244                    padding='same')(H1)
245        H1 = PReLU(alpha_initializer='zero',
246                   shared_axes=shax)(H1)
247        # Output feature map
248        up_block = Add()([H0, H1])
249        return up_block
250    def down_block_2d(H, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
251                    include_dense_convolution_layer=True):
252        if include_dense_convolution_layer == True:
253            H = myconv(filters = number_of_filters,
254                       use_bias=True,
255                       kernel_size=smashConv,
256                       strides=firstStrides,
257                       padding='same')(H)
258            H = PReLU(alpha_initializer='zero',
259                      shared_axes=shax)(H)
260        # Scale down
261        L0 = myconv(filters=number_of_filters,
262                    kernel_size=kernel_size,
263                    strides=strides,
264                    kernel_initializer='glorot_uniform',
265                    padding='same')(H)
266        L0 = PReLU(alpha_initializer='zero',
267                   shared_axes=shax)(L0)
268        # Scale up
269        if idim == 2:
270            H0 = myupsampling( size = strides, interpolation=interpolation )(L0)
271        if idim == 3:
272            H0 = myupsampling( size = strides )(L0)
273        H0 = myconv(filters=number_of_filters,
274                    kernel_size=firstConv,
275                    strides=firstStrides,
276                    use_bias=True,
277                    padding='same')(H0)
278        H0 = PReLU(alpha_initializer='zero',
279                   shared_axes=shax)(H0)
280        # Residual
281        E = Subtract()([H0, H])
282        # Scale residual down
283        L1 = myconv(filters=number_of_filters,
284                    kernel_size=kernel_size,
285                    strides=strides,
286                    kernel_initializer='glorot_uniform',
287                    padding='same')(E)
288        L1 = PReLU(alpha_initializer='zero',
289                   shared_axes=shax)(L1)
290        # Output feature map
291        down_block = Add()([L0, L1])
292        return down_block
293    inputs = Input(shape=input_image_size)
294    # Initial feature extraction
295    model = myconv(filters=number_of_feature_filters,
296                   kernel_size=firstConv,
297                   strides=firstStrides,
298                   padding='same',
299                   kernel_initializer='glorot_uniform')(inputs)
300    model = PReLU(alpha_initializer='zero',
301                  shared_axes=shax)(model)
302    # Feature smashing
303    model = myconv(filters=number_of_base_filters,
304                   kernel_size=smashConv,
305                   strides=firstStrides,
306                   padding='same',
307                   kernel_initializer='glorot_uniform')(model)
308    model = PReLU(alpha_initializer='zero',
309                  shared_axes=shax)(model)
310    # Back projection
311    up_projection_blocks = []
312    down_projection_blocks = []
313    model = up_block_2d(model, number_of_filters=number_of_base_filters,
314      kernel_size=convolution_kernel_size, strides=strides)
315    up_projection_blocks.append(model)
316    for i in range(number_of_back_projection_stages):
317        if i == 0:
318            model = down_block_2d(model, number_of_filters=number_of_base_filters,
319              kernel_size=convolution_kernel_size, strides=strides)
320            down_projection_blocks.append(model)
321            model = up_block_2d(model, number_of_filters=number_of_base_filters,
322              kernel_size=convolution_kernel_size, strides=strides)
323            up_projection_blocks.append(model)
324            model = Concatenate()(up_projection_blocks)
325        else:
326            model = down_block_2d(model, number_of_filters=number_of_base_filters,
327              kernel_size=convolution_kernel_size, strides=strides,
328              include_dense_convolution_layer=True)
329            down_projection_blocks.append(model)
330            model = Concatenate()(down_projection_blocks)
331            model = up_block_2d(model, number_of_filters=number_of_base_filters,
332              kernel_size=convolution_kernel_size, strides=strides,
333              include_dense_convolution_layer=True)
334            up_projection_blocks.append(model)
335            model = Concatenate()(up_projection_blocks)
336    outputs = myconv(filters=number_of_outputs,
337                     kernel_size=last_convolution,
338                     strides=firstStrides,
339                     padding = 'same',
340                     kernel_initializer = "glorot_uniform")(model)
341    if number_of_loss_functions == 1:
342        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputs)
343    else:
344        outputList=[]
345        for k in range(number_of_loss_functions):
346            outputList.append(outputs)
347        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputList)
348    return deep_back_projection_network_model

Creates a Deep Back-Projection Network (DBPN) for single image super-resolution.

This function constructs a Keras model based on the DBPN architecture, which can be configured for either 2D or 3D inputs. The network uses iterative up- and down-projection blocks to refine the high-resolution image estimate. A key modification from the original paper is that upsampling is done with standard interpolation (UpSampling) layers instead of deconvolution (transposed-convolution) layers.

Reference:

  • Haris, M., Shakhnarovich, G., & Ukita, N. (2018). Deep Back-Projection Networks For Super-Resolution. In CVPR.

Parameters

input_image_size : tuple or list The shape of the input image, including the channel. e.g., (None, None, 1) for 2D or (None, None, None, 1) for 3D.

number_of_outputs : int, optional The number of channels in the output image. Default is 1.

number_of_base_filters : int, optional The number of filters in the up/down projection blocks. Default is 64.

number_of_feature_filters : int, optional The number of filters in the initial feature extraction layer. Default is 256.

number_of_back_projection_stages : int, optional The number of iterative back-projection stages (T in the paper). Default is 7.

convolution_kernel_size : tuple or list, optional The kernel size for the main projection convolutions. Should match the dimensionality of the input. Default is (12, 12).

strides : tuple or list, optional The strides for the up/down sampling operations, defining the super-resolution factor. Default is (8, 8).

last_convolution : tuple or list, optional The kernel size of the final reconstruction convolution. Default is (3, 3).

number_of_loss_functions : int, optional If greater than 1, the model will have multiple identical output branches. Typically set to 1. Default is 1.

interpolation : str, optional The interpolation method used by the upsampling layers, 'nearest' or 'bilinear'. It is only applied to 2D inputs; the 3D upsampling layers ignore it. Default is 'nearest'.

Returns

keras.Model A Keras model implementing the DBPN architecture for the specified parameters.

def get_random_base_ind(full_dims, patchWidth, off=8):
def get_random_base_ind( full_dims, patchWidth, off=8 ):
    """
    Generate a random starting (lowest-corner) index for a patch.

    One index is drawn per patch axis, uniformly from
    ``range(off, full_dims[k] - 1 - patchWidth[k])``. Hence
    ``start + patchWidth[k] <= full_dims[k] - 2`` on every axis, and the
    offset ``off`` keeps patches away from the lower border only.

    Parameters
    ----------
    full_dims : tuple or list of int
        The dimensions of the full volume (e.g., ``img.shape``). Must have at
        least ``len(patchWidth)`` entries.

    patchWidth : tuple or list of int
        The dimensions of the patch to be extracted. Its length sets the
        number of axes sampled (3 for volumes, 2 for images).

    off : int, optional
        Minimum starting index along every axis. Default is 8.

    Returns
    -------
    list of int
        The starting coordinates of the patch, one per axis of ``patchWidth``.

    Raises
    ------
    ValueError
        If an axis of the volume is too small to hold the patch at offset
        ``off``. This is the same exception type ``random.sample`` raised
        for an empty range in earlier versions.
    """
    baseInd = []
    for k, width in enumerate( patchWidth ):
        low = off
        high = full_dims[k] - 1 - width
        if high <= low:
            raise ValueError(
                "axis %d of size %d cannot hold a patch of width %d with offset %d"
                % ( k, full_dims[k], width, off ) )
        # random.sample(...)[0] is kept (rather than randrange) so seeded runs
        # reproduce the same patch locations as before.
        baseInd.append( random.sample( range( low, high ), 1 )[0] )
    return baseInd

Generates a random top-left corner index for a patch.

This utility function computes a valid starting index (e.g., [x, y, z]) for extracting a patch from a larger volume, ensuring the patch fits entirely within the volume's boundaries, accounting for an offset.

Parameters

full_dims : tuple or list The dimensions of the full volume (e.g., img.shape).

patchWidth : tuple or list The dimensions of the patch to be extracted.

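off : int, optional Minimum starting index along each axis, which keeps patches away from the lower border of the volume; the upper bound for the start is full_dims[k] - 1 - patchWidth[k] regardless of off. Default is 8.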

Returns

list A list of integers representing the starting coordinates for the patch.

def get_random_patch(img, patchWidth):
def get_random_patch( img, patchWidth, max_tries=1000 ):
    """
    Extract a random patch with non-zero intensity variance.

    Random patch locations are drawn with ``get_random_base_ind`` (offset 8)
    and cropped with ``ants.crop_indices`` until a patch whose ``std()`` is
    non-zero is found. This avoids blank or uniform patches during training
    data generation.

    The search is capped at ``max_tries`` attempts, so an image that has no
    such patch (e.g. a constant image) cannot hang the caller. In that case
    the last sampled patch is returned, which is also how
    ``get_random_patch_pair`` behaves when it gives up.

    Parameters
    ----------
    img : ants.ANTsImage
        The source image from which to extract a patch.

    patchWidth : tuple or list of int
        The desired dimensions of the output patch.

    max_tries : int, optional
        Maximum number of locations to try; at least one is always tried.
        Default is 1000.

    Returns
    -------
    ants.ANTsImage
        A randomly extracted patch from the input image.
    """
    patch = None
    for _ in range( max( 1, max_tries ) ):
        lower = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 )
        # exclusive-style upper corner, one entry per sampled axis
        upper = [ lower[k] + patchWidth[k] for k in range( len( lower ) ) ]
        patch = ants.crop_indices( img, lower, upper )
        if patch.std() != 0:
            break
    return patch

Extracts a random patch from an image with non-zero variance.

This function repeatedly samples a random patch of a specified width from the input image until it finds one where the standard deviation of pixel intensities is greater than zero. This is useful for avoiding blank or uniform patches during training data generation.

Parameters

img : ants.ANTsImage The source image from which to extract a patch.

patchWidth : tuple or list The desired dimensions of the output patch.

Returns

ants.ANTsImage A randomly extracted patch from the input image.

def get_random_patch_pair(img, img2, patchWidth):
def get_random_patch_pair( img, img2, patchWidth, max_tries=21 ):
    """
    Extract patches from the same random location in two aligned images.

    Up to ``max_tries`` random locations are drawn with
    ``get_random_base_ind`` (offset 8). The first location at which both
    cropped patches have non-zero ``std()`` is returned. If no such location
    is found, the last sampled pair is returned, so the non-zero-variance
    property is best-effort rather than guaranteed. This is useful for
    creating paired training data (e.g., low-res and high-res images).

    Parameters
    ----------
    img : ants.ANTsImage
        The first source image.

    img2 : ants.ANTsImage
        The second source image, spatially aligned with the first.

    patchWidth : tuple or list of int
        The desired dimensions of the output patches.

    max_tries : int, optional
        Maximum number of locations to try; at least one is always tried.
        Default is 21.

    Returns
    -------
    tuple of ants.ANTsImage
        A tuple containing two corresponding patches: (patch_from_img, patch_from_img2).
    """
    for _ in range( max( 1, max_tries ) ):
        inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 )
        # same crop box for both images so the patches stay aligned
        hinds = [ inds[k] + patchWidth[k] for k in range( len( inds ) ) ]
        myimg = ants.crop_indices( img, inds, hinds )
        myimg2 = ants.crop_indices( img2, inds, hinds )
        if myimg.std() != 0 and myimg2.std() != 0:
            break
    return myimg, myimg2

Extracts a corresponding random patch from a pair of images.

This function finds a single random location and extracts a patch of the same size and position from two different input images. It tries up to 21 random locations and stops at the first one where both extracted patches have non-zero variance; if none is found, the last sampled pair is returned. This is useful for creating paired training data (e.g., low-res and high-res images).

Parameters

img : ants.ANTsImage The first source image.

img2 : ants.ANTsImage The second source image, spatially aligned with the first.

patchWidth : tuple or list The desired dimensions of the output patches.

Returns

tuple of ants.ANTsImage A tuple containing two corresponding patches: (patch_from_img, patch_from_img2).

def pseudo_3d_vgg_features( inshape=[128, 128, 128], layer=4, angle=0, pretrained=True, verbose=False):
def pseudo_3d_vgg_features( inshape = [128,128,128], layer = 4, angle=0, pretrained=True, verbose=False ):
    """
    Create a pseudo-3D VGG feature extractor from a 2D VGG19 model.

    A 3D VGG19-style network is built with ``antspynet.create_vgg_model_3d``
    and truncated at the layer selected by ``layer``. When ``pretrained`` is
    True, its weights are initialised from the ImageNet-trained Keras VGG19
    (2D), truncated at the same layer index:

    - weight tensors whose shapes already match are copied unchanged;
    - the first weight tensor receives the 2D tensor in its single
      input-channel slot (``[:, :, :, 0, :]``). Assuming channels-last
      kernels, this places the 2D RGB axis on the third spatial axis;
    - every other mismatching tensor gets the 2D tensor divided by 3 written
      into the first three slices of the axis chosen by ``angle``, so those
      three slices sum to the 2D kernel.

    Parameters
    ----------
    inshape : list of int, optional
        Spatial shape of the input volume. Default is ``[128,128,128]``.
        The first two entries also size the 2D VGG19 used as weight source.

    layer : int, optional
        One-based layer position. ``layer - 1`` is used directly as an index
        into the ``layers`` lists of both the 3D and the 2D model, so it is
        NOT a VGG block number. Use ``verbose=True`` to print which layers
        are selected. Default is 4.

    angle : int, optional
        Must be 0, 1 or 2. 2D kernels are replicated along 3D kernel axis 2,
        1 or 0 respectively. This is intended as axial / coronal / sagittal,
        which depends on image orientation. Default is 0.

    pretrained : bool, optional
        If True, transfer the ImageNet weights and prepend a
        ``tf.keras.layers.Rescaling(255.0, -127.5)`` layer (presumably inputs
        are expected in [0, 1] -- confirm). If False, the truncated 3D model
        is returned with its initial weights and no ImageNet weights are
        downloaded. Default is True.

    verbose : bool, optional
        If True, prints the selected 3D and 2D layers. Default is False.

    Returns
    -------
    tf.keras.Model
        A Keras model mapping a 3D volume to the selected layer's features.

    Raises
    ------
    ValueError
        If ``angle`` is not 0, 1 or 2.
    """
    if angle not in (0, 1, 2):
        raise ValueError("angle must be 0, 1 or 2, got " + str(angle))
    # The 2D model only supplies weights to transfer (and layer names for
    # verbose output), so skip the ImageNet download when not pretrained.
    vgg19 = tf.keras.applications.VGG19(
            include_top = False,
            weights = "imagenet" if pretrained else None,
            input_shape = [inshape[0],inshape[1],3],
            classes = 1000 )
    layer_index = layer-1
    vggmodelRaw = antspynet.create_vgg_model_3d(
            [inshape[0],inshape[1],inshape[2],1],
            number_of_classification_labels = 1000,
            layers = [1, 2, 3, 4, 4],
            lowest_resolution = 64,
            convolution_kernel_size= (3, 3, 3), pool_size = (2, 2, 2),
            strides = (2, 2, 2), number_of_dense_units= 4096, dropout_rate = 0,
            style = 19, mode = "classification")
    if verbose:
        print( vggmodelRaw.layers[layer_index] )
        print( vggmodelRaw.layers[layer_index].name )
        print( vgg19.layers[layer_index] )
        print( vgg19.layers[layer_index].name )
    feature_extractor = tf.keras.Model(
            inputs = vggmodelRaw.input,
            outputs = vggmodelRaw.layers[layer_index].output)
    if not pretrained:
        return feature_extractor
    feature_extractor_2d = tf.keras.Model(
            inputs = vgg19.input,
            outputs = vgg19.layers[layer_index].output)
    wts_2d = [ w.numpy() for w in feature_extractor_2d.weights ]
    wts = [ w.numpy() for w in feature_extractor.weights ]
    for ww in range( len( wts ) ):
        if ww == 0:
            # first kernel: 2D input channels land on the third spatial axis
            wts[ww][:,:,:,0,:] = wts_2d[ww]
        elif wts[ww].shape == wts_2d[ww].shape:
            wts[ww] = wts_2d[ww]
        else:
            # FIXME - assumes exactly three slices along the stretch axis
            scaled = wts_2d[ww] / 3.0
            for i in range( 3 ):
                if angle == 0:
                    wts[ww][:, :, i, :, :] = scaled
                elif angle == 1:
                    wts[ww][:, i, :, :, :] = scaled
                else:
                    wts[ww][i, :, :, :, :] = scaled
    feature_extractor.set_weights( wts )
    newinput = tf.keras.layers.Rescaling(  255.0, -127.5  )( feature_extractor.input )
    feature_extractor2 = feature_extractor( newinput )
    return tf.keras.Model( feature_extractor.input, feature_extractor2 )

Creates a pseudo-3D VGG feature extractor from a pre-trained 2D VGG model.

This function constructs a 3D VGG-style network and initializes its weights by "stretching" the weights from a pre-trained 2D VGG19 model (trained on ImageNet) along a specified axis. This is a technique to transfer 2D perceptual knowledge to a 3D domain for tasks like perceptual loss.

Parameters

inshape : list of int, optional The input shape of the 3D volume, e.g., [128, 128, 128]. Default is [128,128,128].

layer : int, optional One-based layer position: layer - 1 is used directly as an index into the layers lists of both the 3D and the 2D VGG model, so it is not a VGG block number (e.g., layer=4 selects the layer at index 3). Use verbose=True to print the selected layers. Default is 4.

angle : int, optional The axis along which to project the 2D weights: - 0: Axial plane (stretches along Z) - 1: Coronal plane (stretches along Y) - 2: Sagittal plane (stretches along X) Default is 0.

pretrained : bool, optional If True, loads the stretched ImageNet weights. If False, the model is randomly initialized. Default is True.

verbose : bool, optional If True, prints information about the layers being used. Default is False.

Returns

tf.keras.Model A Keras model that takes a 3D volume as input and outputs the pseudo-3D feature map from the specified layer and angle.

def pseudo_3d_vgg_features_unbiased(inshape=[128, 128, 128], layer=4, verbose=False):
def pseudo_3d_vgg_features_unbiased( inshape = [128,128,128], layer = 4, verbose=False ):
    """
    Build a plane-agnostic pseudo-3D VGG feature extractor.

    Three ``pseudo_3d_vgg_features`` models are created with ``angle`` 0, 1
    and 2 (all pretrained). Each is applied to one shared input, and their
    outputs are concatenated with ``tf.keras.layers.concatenate`` (last axis).
    This is useful for perceptual losses that should not favour a single
    anatomical plane.

    Parameters
    ----------
    inshape : list of int, optional
        Input volume shape, passed to every per-angle extractor.
        Default is [128, 128, 128].

    layer : int, optional
        Layer selector passed unchanged to ``pseudo_3d_vgg_features``.
        Default is 4.

    verbose : bool, optional
        Forwarded only to the angle-0 extractor; the other two are built
        with ``verbose=False``.

    Returns
    -------
    tf.keras.Model
        Model mapping the shared input to the concatenated per-angle features.

    See Also
    --------
    pseudo_3d_vgg_features : Single-angle extractor.
    """
    per_angle = []
    for angle in ( 0, 1, 2 ):
        per_angle.append(
            pseudo_3d_vgg_features( inshape, layer, angle=angle, pretrained=True,
                                    verbose=( verbose if angle == 0 else False ) ) )
    shared_inputs = per_angle[0].inputs
    responses = [ extractor( shared_inputs ) for extractor in per_angle ]
    return tf.keras.Model( shared_inputs, tf.keras.layers.concatenate( responses ) )

Create a pseudo-3D VGG-style feature extractor by aggregating axial, coronal, and sagittal VGG feature representations.

This model extracts features along each principal axis using pre-trained 2D VGG-style networks and concatenates them to form an unbiased pseudo-3D feature space.

Parameters

inshape : list of int, optional The input shape of the 3D volume, default is [128, 128, 128].

layer : int, optional The VGG feature layer to extract. Higher values correspond to deeper layers in the pseudo-3D VGG backbone.

verbose : bool, optional If True, prints debug messages during model construction. Only the axial (angle=0) extractor receives this flag; the other two are always built silently.

Returns

tf.keras.Model A TensorFlow Keras model that takes a 3D input volume and outputs the concatenated pseudo-3D feature representation from the specified layer.

Notes

This is useful for perceptual loss or feature comparison in super-resolution and image synthesis tasks. The same input is processed in three anatomical planes (axial, coronal, sagittal), and features are concatenated.

See Also

pseudo_3d_vgg_features : Generates VGG features from a single anatomical plane.

def get_grader_feature_network(layer=6):
def get_grader_feature_network( layer=6 ):
    """
    Load the pre-trained 3D ResNet "grader" and return a feature sub-network.

    The grader is built with ``antspynet.create_resnet_model_3d`` (one input
    channel, 4 outputs). Its weights are read from
    ``~/.antspyt1w/resnet_grader.h5``; that folder and file are provided by
    antspyt1w's data download. The file check runs before the network is
    built, so a missing file fails fast.

    Parameters
    ----------
    layer : int, optional
        Index into ``grader.layers`` whose output becomes the model output.
        Default is 6.

    Returns
    -------
    tf.keras.Model
        Model from the grader's inputs to the output of ``grader.layers[layer]``.

    Raises
    ------
    FileNotFoundError
        If ``~/.antspyt1w/resnet_grader.h5`` does not exist. This is a
        subclass of ``Exception`` (the type raised previously), so existing
        ``except Exception`` handlers still catch it.

    Notes
    -----
    This model is typically used to compute perceptual loss by comparing
    intermediate activations between target and prediction volumes.

    See Also
    --------
    antspynet.create_resnet_model_3d : Constructs the base ResNet model.
    """
    # the folder and data below are available from antspyt1w get_data
    graderfn = os.path.expanduser( "~/.antspyt1w/resnet_grader.h5" )
    if not exists( graderfn ):
        raise FileNotFoundError("graderfn " + graderfn + " does not exist")
    grader = antspynet.create_resnet_model_3d(
        [None,None,None,1],
        lowest_resolution = 32,
        number_of_outputs = 4,
        cardinality = 1,
        squeeze_and_excite = False )
    grader.load_weights( graderfn )
    return tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[layer].output )

Load and extract a ResNet-based feature subnetwork for perceptual loss or quality grading.

This function loads a pre-trained 3D ResNet model ("grader") used for perceptual feature extraction and returns a subnetwork that outputs activations from a specified internal layer.

Parameters

layer : int, optional The index of the internal ResNet layer whose output should be used as the feature representation. Default is layer 6.

Returns

tf.keras.Model A Keras model that outputs features from the specified layer of the pre-trained 3D ResNet grader model.

Raises

Exception If the pre-trained weights file (resnet_grader.h5) is not found.

Notes

The pre-trained weights should be located in: ~/.antspyt1w/resnet_grader.h5

This model is typically used to compute perceptual loss by comparing intermediate activations between target and prediction volumes.

See Also

antspynet.create_resnet_model_3d : Constructs the base ResNet model.

def default_dbpn( strider, dimensionality=3, nfilt=64, nff=256, convn=6, lastconv=3, nbp=7, nChannelsIn=1, nChannelsOut=1, option=None, intensity_model=None, segmentation_model=None, sigmoid_second_channel=False, clone_intensity_to_segmentation=False, pro_seg=0, freeze=False, verbose=False):
def default_dbpn(
    strider, # length should equal dimensionality
    dimensionality = 3,
    nfilt=64,
    nff = 256,
    convn = 6,
    lastconv = 3,
    nbp=7,
    nChannelsIn=1,
    nChannelsOut=1,
    option = None,
    intensity_model=None,
    segmentation_model=None,
    sigmoid_second_channel=False,
    clone_intensity_to_segmentation=False,
    pro_seg = 0,
    freeze = False,
    verbose=False
):
    """
    Construct a DBPN super-resolution model, optionally with a segmentation output.

    Three configurations are possible:

    * default: a single ``dbpn`` network mapping ``nChannelsIn`` to
      ``nChannelsOut`` channels (nearest interpolation).
    * ``sigmoid_second_channel=True`` and ``pro_seg != 0``: a two-channel
      input is split along the channel axis. Channel 0 goes through an
      intensity DBPN (nearest interpolation), and channel 1 goes through a
      segmentation DBPN (linear interpolation) followed by a sigmoid. The two
      outputs are concatenated.
    * otherwise, ``pro_seg > 0`` with a user-supplied ``intensity_model``:
      the same ``intensity_model`` is applied to both input channels. The
      second result passes through two convolutions with ``nff`` filters and
      a single-filter convolution of kernel size ``pro_seg``, then a sigmoid.

    Args:
        strider (list): Up/down-sampling strides per axis; length must equal
            ``dimensionality``.
        dimensionality (int): Number of spatial dimensions (2 or 3). Default is 3.
        nfilt (int): Number of base filters. Default is 64.
        nff (int): Number of feature filters. Default is 256.
        convn (int): Convolution kernel size. Default is 6.
        lastconv (int): Kernel size of the last convolution. Default is 3.
        nbp (int): Number of back-projection stages. Default is 7.
        nChannelsIn (int): Input channels of the default model. Default is 1.
        nChannelsOut (int): Output channels of the default model. Default is 1.
        option (str): 'tiny', 'small' or 'medium' override nfilt, nff, convn,
            lastconv and nbp with presets; any other value (including None)
            keeps the passed values and is reported as 'large'.
        intensity_model (tf.keras.Model): Optional pre-built intensity network.
        segmentation_model (tf.keras.Model): Optional pre-built segmentation network.
        sigmoid_second_channel (bool): Selects the two-network configuration
            (only together with a non-zero ``pro_seg``).
        clone_intensity_to_segmentation (bool): In the two-network
            configuration, copy each intensity weight into the segmentation
            weight at the same position when their shapes match.
        pro_seg (int): Kernel size of the final segmentation convolution;
            0 disables both segmentation configurations.
        freeze (bool): Mark the layers of user-supplied models as non-trainable.
        verbose (bool): If True, prints progress information.

    Returns:
        Model: A Keras model based on the specified configuration.

    Raises:
        Exception: If ``len(strider)`` is not equal to ``dimensionality``.
        ValueError: If ``dimensionality`` is not 2 or 3.
    """
    # option presets: (nfilt, nff, convn, lastconv, nbp)
    presets = {
        'tiny':   (32, 64, 3, 1, 2),
        'small':  (32, 64, 6, 3, 4),
        'medium': (64, 128, 6, 3, 4),
    }
    if isinstance(option, str) and option in presets:
        nfilt, nff, convn, lastconv, nbp = presets[option]
    else:
        option = 'large'
    if verbose:
        print("Building mode of size: " + option)
        # user-passed models are only frozen when freeze is requested
        if freeze and intensity_model is not None:
            print("user-passed intensity model will be frozen - only segmentation will train")
        if freeze and segmentation_model is not None:
            print("user-passed segmentation model will be frozen - only intensity will train")

    if len(strider) != dimensionality:
        raise Exception("len(strider) != dimensionality")
    if dimensionality not in (2, 3):
        raise ValueError("dimensionality must be 2 or 3, got " + str(dimensionality))

    # empirical evidence suggests that making convolutions and strides evenly
    # divisible by each other reduces artifacts.  2*3=6.
    spatial = (None,) * dimensionality
    channel_axis = dimensionality + 1  # axis 0 is the batch axis

    def build( n_in, n_out, interpolation ):
        # one DBPN sharing this call's hyper-parameters
        return dbpn( spatial + (n_in,),
            number_of_outputs=n_out,
            number_of_base_filters=nfilt,
            number_of_feature_filters=nff,
            number_of_back_projection_stages=nbp,
            convolution_kernel_size=(convn,) * dimensionality,
            strides=tuple( strider[i] for i in range( dimensionality ) ),
            last_convolution=(lastconv,) * dimensionality,
            number_of_loss_functions=1,
            interpolation=interpolation)

    def set_frozen( model ):
        # mark every layer of a user-supplied model as non-trainable
        for lyr in model.layers:
            lyr.trainable = False

    mdl = build( nChannelsIn, nChannelsOut, 'nearest' )

    if sigmoid_second_channel and pro_seg != 0 :
        if intensity_model is None:
            intensity_model = build( 1, 1, 'nearest' )
        elif freeze:
            set_frozen( intensity_model )
        if segmentation_model is None:
            segmentation_model = build( 1, 1, 'linear' )
        elif freeze:
            set_frozen( segmentation_model )
        if verbose:
            print( "len intensity_model layers : " + str( len( intensity_model.layers )))
            print( "len intensity_model weights : " + str( len( intensity_model.weights )))
            print( "len segmentation_model layers : " + str( len( segmentation_model.layers )))
            print( "len segmentation_model weights : " + str( len( segmentation_model.weights )))
        if clone_intensity_to_segmentation:
            # `.weights` returns a new list on each access, so the variables
            # themselves must be updated with assign(); item assignment into
            # that list would not change the model.
            for seg_w, int_w in zip( segmentation_model.weights, intensity_model.weights ):
                if int_w.shape == seg_w.shape:
                    seg_w.assign( int_w )
        inputs = tf.keras.Input(shape=spatial + (2,))
        insplit = tf.split( inputs, 2, channel_axis )
        outputs = [
            intensity_model( insplit[0] ),
            tf.nn.sigmoid( segmentation_model( insplit[1] ) ) ]
        mdlout = tf.concat( outputs, axis=channel_axis )
        return Model(inputs=inputs, outputs=mdlout )

    if pro_seg > 0 and intensity_model is not None:
        if verbose:
            if freeze:
                print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid")
            else:
                print("Add a segmentation arm to the end. intensity remains trainable. intensity_model(seg) => conv => sigmoid")
        if freeze:
            set_frozen( intensity_model )
        myconv = Conv2D if dimensionality == 2 else Conv3D
        firstConv = (convn,) * dimensionality
        firstStrides = (1,) * dimensionality
        smashConv = (pro_seg,) * dimensionality
        inputs = tf.keras.Input(shape=spatial + (2,))
        insplit = tf.split( inputs, 2, channel_axis )
        # segmentation arm: the same intensity_model applied to channel 1
        seggit = intensity_model( insplit[1] )
        L0 = myconv(filters=nff,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(seggit)
        L1 = myconv(filters=nff,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(L0)
        L2 = myconv(filters=1,
                    kernel_size=smashConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(L1)
        outputs = [
            intensity_model( insplit[0] ),
            tf.nn.sigmoid( L2 ) ]
        mdlout = tf.concat( outputs, axis=channel_axis )
        return Model(inputs=inputs, outputs=mdlout )
    return mdl

Constructs a DBPN model based on input parameters, and can optionally use external models for intensity or segmentation processing.

Args: strider (list): List of strides, length must match dimensionality. dimensionality (int): Number of dimensions (2 or 3). Default is 3. nfilt (int): Number of base filters. Default is 64. nff (int): Number of feature filters. Default is 256. convn (int): Convolution kernel size. Default is 6. lastconv (int): Size of the last convolution. Default is 3. nbp (int): Number of back projection stages. Default is 7. nChannelsIn (int): Number of input channels. Default is 1. nChannelsOut (int): Number of output channels. Default is 1. option (str): Model size option ('tiny', 'small', 'medium', 'large'); 'tiny', 'small' and 'medium' override nfilt, nff, convn, lastconv and nbp with presets, while any other value keeps the passed values. Default is None. intensity_model (tf.keras.Model): Optional external intensity model. segmentation_model (tf.keras.Model): Optional external segmentation model. sigmoid_second_channel (bool): If True and pro_seg is non-zero, builds separate intensity and segmentation networks and applies a sigmoid to the second output channel. clone_intensity_to_segmentation (bool): If True, clones intensity model weights to segmentation. pro_seg (int): If greater than 0 and an intensity_model is given (and sigmoid_second_channel is False), adds a segmentation arm whose final convolution uses a kernel of size pro_seg. freeze (bool): If True, freezes the layers of user-supplied intensity/segmentation models. verbose (bool): If True, prints detailed logs.

Returns: Model: A Keras model based on the specified configuration.

Raises: Exception: If len(strider) is not equal to dimensionality.

def image_patch_training_data_from_filenames(
    filenames,
    target_patch_size,
    target_patch_size_low,
    nPatches = 128,
    istest   = False,
    patch_scaler=True,
    to_tensorflow = False,
    verbose = False
    ):
    """
    Generates a batch of paired high- and low-resolution image patches for training.

    This function creates training data by taking a list of high-resolution source
    images, extracting random patches, and then downsampling them to create
    low-resolution counterparts. This provides the (input, ground_truth) pairs
    needed to train a super-resolution model.

    Parameters
    ----------
    filenames : list of str
        A list of file paths to the high-resolution source images. One file is
        drawn at random (with replacement across patches) for every patch.

    target_patch_size : tuple or list of int
        The dimensions of the high-resolution (ground truth) patch to extract,
        e.g., `(128, 128, 128)`. Must have 2 or 3 entries.

    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution (input) patch to generate. The
        rounded per-axis ratio between `target_patch_size` and
        `target_patch_size_low` determines the super-resolution factor.

    nPatches : int, optional
        The number of patch pairs to generate in this batch. Default is 128.

    istest : bool, optional
        If True, the function also fills a third output array with the
        low-resolution patches resampled back to the original spacing using
        linear interpolation. This is useful as a baseline when computing
        evaluation metrics (e.g., PSNR). Default is False.

    patch_scaler : bool, optional
        If True, scales the intensity of each high-resolution patch to the [0, 1]
        range before creating the downsampled version. Default is True.

    to_tensorflow : bool, optional
        If True, casts the output NumPy arrays to float32 TensorFlow tensors.
        Default is False.

    verbose : bool, optional
        If True, prints progress messages during patch generation. Default is False.

    Returns
    -------
    tuple
        Always a 3-tuple `(patchesResam, patchesOrig, patchesUp)`:
        - `patchesResam`: low-resolution input patches (X_train), shape
          `(nPatches, *target_patch_size_low, 1)`.
        - `patchesOrig`: high-resolution ground-truth patches (y_train), shape
          `(nPatches, *target_patch_size, 1)`.
        - `patchesUp`: baseline upsampled patches shaped like `patchesOrig`
          when `istest` is True, otherwise `None`.

    Raises
    ------
    ValueError
        If `target_patch_size` does not have 2 or 3 entries.
    """
    if verbose:
        print("begin image_patch_training_data_from_filenames")
    tardim = len( target_patch_size )
    if tardim not in (2, 3):
        raise ValueError(
            "target_patch_size must have 2 or 3 entries; got " + str( tardim ) )
    # Per-axis downsampling factor between the high- and low-resolution grids.
    strider = [ np.round( target_patch_size[j] / target_patch_size_low[j] )
                for j in range( tardim ) ]
    # Batch-first arrays with a trailing singleton channel axis.
    shaperhi = ( nPatches, *target_patch_size, 1 )
    shaperlo = ( nPatches, *target_patch_size_low[:tardim], 1 )
    patchesOrig = np.zeros( shape=shaperhi )
    patchesResam = np.zeros( shape=shaperlo )
    patchesUp = np.zeros( shape=patchesOrig.shape ) if istest else None
    for myn in range( nPatches ):
        if verbose:
            print(myn)
        imgfn = random.sample( filenames, 1 )[0]
        if verbose:
            print(imgfn)
        img = ants.image_read( imgfn ).iMath("Normalize")
        if img.components > 1:
            # Multi-component input: keep only the first component.
            img = ants.split_channels(img)[0]
        img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
        ants.set_origin( img, ants.get_center_of_mass(img) )
        img = ants.iMath(img,"Normalize")
        spc = ants.get_spacing( img )
        # Coarser spacing that produces the low-resolution grid.
        newspc = [ spc[jj] * strider[jj] for jj in range( len( spc ) ) ]
        # Randomly pick the downsampling interpolator (0 = linear,
        # 1 = nearest neighbour in ANTsPy's resample_image).
        interp_type = random.choice( [0,1] )
        imgp = get_random_patch( img, target_patch_size )
        if patch_scaler:
            imgp = imgp - imgp.min()
            imgpmax = imgp.max()
            if imgpmax > 0 :
                imgp = imgp / imgpmax
        rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type  )
        # Ellipsis spans the 2 or 3 spatial axes; channel 0 is the only channel.
        patchesOrig[myn, ..., 0] = imgp.numpy()
        patchesResam[myn, ..., 0] = rimgp.numpy()
        if istest:
            rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0  )
            patchesUp[myn, ..., 0] = rimgbi.numpy()
    if to_tensorflow:
        patchesOrig = tf.cast( patchesOrig, "float32")
        patchesResam = tf.cast( patchesResam, "float32")
        if istest:
            patchesUp = tf.cast( patchesUp, "float32")
    return patchesResam, patchesOrig, patchesUp

Generates a batch of paired high- and low-resolution image patches for training.

This function creates training data by taking a list of high-resolution source images, extracting random patches, and then downsampling them to create low-resolution counterparts. This provides the (input, ground_truth) pairs needed to train a super-resolution model.

Parameters

filenames : list of str A list of file paths to the high-resolution source images.

target_patch_size : tuple or list of int The dimensions of the high-resolution (ground truth) patch to extract, e.g., (128, 128, 128).

target_patch_size_low : tuple or list of int The dimensions of the low-resolution (input) patch to generate. The ratio between target_patch_size and target_patch_size_low determines the super-resolution factor.

nPatches : int, optional The number of patch pairs to generate in this batch. Default is 128.

istest : bool, optional If True, the function also generates a third output array containing patches that have been naively upsampled using linear interpolation. This is useful for calculating baseline evaluation metrics (e.g., PSNR) against which the model's performance can be compared. Default is False.

patch_scaler : bool, optional If True, scales the intensity of each high-resolution patch to the [0, 1] range before creating the downsampled version. This can help with training stability. Default is True.

to_tensorflow : bool, optional If True, casts the output NumPy arrays to TensorFlow tensors. Default is False.

verbose : bool, optional If True, prints progress messages during patch generation. Default is False.

Returns

tuple A 3-tuple (patchesResam, patchesOrig, patchesUp) of NumPy arrays, or float32 TensorFlow tensors when to_tensorflow is True. - patchesResam: The batch of low-resolution input patches (X_train). - patchesOrig: The batch of high-resolution ground truth patches (y_train). - patchesUp: The batch of baseline, linearly-upsampled patches when istest is True; otherwise None.

def seg_patch_training_data_from_filenames(
    filenames,
    target_patch_size,
    target_patch_size_low,
    nPatches = 128,
    istest   = False,
    patch_scaler=True,
    to_tensorflow = False,
    verbose = False
    ):
    """
    Generates a batch of paired training data containing both images and segmentations.

    This function extends `image_patch_training_data_from_filenames` by adding a
    second channel to the data. For each extracted image patch, it also generates
    a corresponding segmentation mask by Otsu thresholding the patch and keeping
    one randomly chosen class (1 or 2). This is useful for training multi-task
    models that perform super-resolution on both an image and its associated
    segmentation simultaneously.

    Parameters
    ----------
    filenames : list of str
        A list of file paths to the high-resolution source images.

    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patch, e.g., `(128, 128, 128)`.
        Must have 2 or 3 entries.

    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution input patch.

    nPatches : int, optional
        The number of patch pairs to generate. Default is 128.

    istest : bool, optional
        If True, also generates a third output array containing baseline upsampled
        intensity images (channel 0 only; channel 1 stays zero). Default is False.

    patch_scaler : bool, optional
        If True, scales the intensity of each image patch to the [0, 1] range
        before thresholding and downsampling. Default is True.

    to_tensorflow : bool, optional
        If True, casts the output NumPy arrays to float32 TensorFlow tensors.
        Default is False.

    verbose : bool, optional
        If True, prints progress messages. Default is False.

    Returns
    -------
    tuple
        Always a 3-tuple `(patchesResam, patchesOrig, patchesUp)` with the same
        meaning as in `image_patch_training_data_from_filenames`, but each array
        has a trailing channel dimension of 2:
        - Channel 0: the intensity image.
        - Channel 1: the segmentation mask. It is binary at high resolution. In
          the low-resolution array it may contain fractional values when linear
          interpolation (interp_type 0) was drawn for that patch.
        `patchesUp` is `None` unless `istest` is True.

    Raises
    ------
    ValueError
        If `target_patch_size` does not have 2 or 3 entries.
    """
    if verbose:
        print("begin seg_patch_training_data_from_filenames")
    tardim = len( target_patch_size )
    if tardim not in (2, 3):
        raise ValueError(
            "target_patch_size must have 2 or 3 entries; got " + str( tardim ) )
    nchan = 2
    # Per-axis downsampling factor between the high- and low-resolution grids.
    strider = [ np.round( target_patch_size[j] / target_patch_size_low[j] )
                for j in range( tardim ) ]
    # Batch-first arrays with a trailing (intensity, segmentation) channel axis.
    shaperhi = ( nPatches, *target_patch_size, nchan )
    shaperlo = ( nPatches, *target_patch_size_low[:tardim], nchan )
    patchesOrig = np.zeros( shape=shaperhi )
    patchesResam = np.zeros( shape=shaperlo )
    patchesUp = np.zeros( shape=patchesOrig.shape ) if istest else None
    for myn in range( nPatches ):
        if verbose:
            print(myn)
        imgfn = random.sample( filenames, 1 )[0]
        if verbose:
            print(imgfn)
        img = ants.image_read( imgfn ).iMath("Normalize")
        if img.components > 1:
            # Multi-component input: keep only the first component.
            img = ants.split_channels(img)[0]
        img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
        ants.set_origin( img, ants.get_center_of_mass(img) )
        img = ants.iMath(img,"Normalize")
        spc = ants.get_spacing( img )
        newspc = [ spc[jj] * strider[jj] for jj in range( len( spc ) ) ]
        # Random downsampling interpolator (0 = linear, 1 = nearest neighbour
        # in ANTsPy) and random Otsu class to keep as the foreground label.
        interp_type = random.choice( [0,1] )
        seg_class = random.choice( [1,2] )
        imgp = get_random_patch( img, target_patch_size )
        if patch_scaler:
            imgp = imgp - imgp.min()
            imgpmax = imgp.max()
            if imgpmax > 0 :
                imgp = imgp / imgpmax
        segp = ants.threshold_image( imgp, "Otsu", 2 ).threshold_image( seg_class, seg_class )
        rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type  )
        rsegp = ants.resample_image( segp, newspc, use_voxels = False, interp_type=interp_type  )
        # Ellipsis spans the 2 or 3 spatial axes.
        patchesOrig[myn, ..., 0] = imgp.numpy()
        patchesResam[myn, ..., 0] = rimgp.numpy()
        patchesOrig[myn, ..., 1] = segp.numpy()
        patchesResam[myn, ..., 1] = rsegp.numpy()
        if istest:
            rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0  )
            patchesUp[myn, ..., 0] = rimgbi.numpy()
    if to_tensorflow:
        patchesOrig = tf.cast( patchesOrig, "float32")
        patchesResam = tf.cast( patchesResam, "float32")
        if istest:
            patchesUp = tf.cast( patchesUp, "float32")
    return patchesResam, patchesOrig, patchesUp

Generates a batch of paired training data containing both images and segmentations.

This function extends image_patch_training_data_from_filenames by adding a second channel to the data. For each extracted image patch, it also generates a corresponding segmentation mask using Otsu's thresholding. This is useful for training multi-task models that perform super-resolution on both an image and its associated segmentation simultaneously.

Parameters

filenames : list of str A list of file paths to the high-resolution source images.

target_patch_size : tuple or list of int The dimensions of the high-resolution patch, e.g., (128, 128, 128).

target_patch_size_low : tuple or list of int The dimensions of the low-resolution input patch.

nPatches : int, optional The number of patch pairs to generate. Default is 128.

istest : bool, optional If True, also generates a third output array containing baseline upsampled intensity images (channel 0 only). Default is False.

patch_scaler : bool, optional If True, scales the intensity of each image patch to the [0, 1] range. Default is True.

to_tensorflow : bool, optional If True, casts the output NumPy arrays to TensorFlow tensors. Default is False.

verbose : bool, optional If True, prints progress messages. Default is False.

Returns

tuple A tuple of multi-channel NumPy arrays or TensorFlow tensors. The structure is the same as image_patch_training_data_from_filenames, but each array has a channel dimension of 2: - Channel 0: The intensity image. - Channel 1: The binary segmentation mask.

def read( filename ):
    """
    Reads an image or a NumPy array from a file.

    Paths ending in `.npy` are loaded with `numpy.load`. Every other path is
    handed to `ants.image_read` (e.g. .nii.gz, .mha).

    Parameters
    ----------
    filename : str or os.PathLike
        The full path to the file to be read.

    Returns
    -------
    ants.ANTsImage or np.ndarray
        The loaded data object, either as an ANTsImage or a NumPy array.
    """
    # Check the extension only. A regex search for ".npy" (unescaped dot,
    # matched anywhere in the path) would also hit names such as
    # "data/xnpy_scan.nii.gz" and wrongly route them to numpy.load.
    if os.fspath( filename ).endswith( ".npy" ):
        return np.load( filename )
    return ants.image_read( filename )

Reads an image or a NumPy array from a file.

This function acts as a wrapper to intelligently load data. It checks the file extension to decide whether to use ants.image_read for standard medical image formats (e.g., .nii.gz, .mha) or numpy.load for .npy files.

Parameters

filename : str The full path to the file to be read.

Returns

ants.ANTsImage or np.ndarray The loaded data object, either as an ANTsImage or a NumPy array.

def auto_weight_loss( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    based on intensity (MSE), perceptual similarity (feature), and total variation (TV).

    The MSE weight is fixed at 10.0. The feature and TV weights are chosen so
    that, on this batch, the weighted feature term equals `feature` times the
    weighted MSE term and the weighted TV term equals `tv` times the weighted
    MSE term.

    Parameters
    ----------
    mdl : tf.keras.Model
        A trained or untrained model to evaluate predictions on input `x`.

    feature_extractor : tf.keras.Model
        A model that extracts intermediate features from the input. Commonly a VGG or ResNet
        trained on a perceptual task.

    x : tf.Tensor
        Input batch to the model.

    y : tf.Tensor
        Ground truth target for `x`: a rank-4 (2-D images) or rank-5 (3-D
        volumes) batch with the channel axis last.

    feature : float, optional
        Weighting factor for the feature (perceptual) term in the loss. Default is 2.0.

    tv : float, optional
        Weighting factor for the total variation term in the loss. Default is 0.1.

    verbose : bool, optional
        If True, prints each component of the loss and its scaled value.

    Returns
    -------
    list of float
        A list of computed weights in the order:
        `[msq_weight, feature_weight, tv_weight]`

    Raises
    ------
    ValueError
        If `y` is not rank 4 or rank 5.

    Notes
    -----
    The total loss (to be used during training) can then be constructed as:

        `L = msq_weight * MSE + feature_weight * perceptual_loss + tv_weight * TV`

    This function is typically used to balance loss terms before training.
    """
    y_pred = mdl( x )
    rank = len( y.shape )
    if rank == 5:
        tdim = 3
    elif rank == 4:
        tdim = 2
    else:
        raise ValueError(
            "y must be a rank-4 (2-D) or rank-5 (3-D) batch; got rank " + str( rank ) )
    # Reduce over every non-batch axis, giving one value per sample.
    myax = list( range( 1, rank ) )
    squared_difference = tf.square( y - y_pred )
    msqTerm = tf.reduce_mean( squared_difference, axis=myax )
    temp1 = feature_extractor( y )
    temp2 = feature_extractor( y_pred )
    featureTerm = tf.reduce_mean( tf.square( temp1 - temp2 ), axis=myax )
    msqw = 10.0
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        # One volume at a time: TF reads the (D, H, W, C) slice as D images, so
        # this is per-slice TV averaged over slices, then summed over the batch.
        # NOTE(review): the 2-D branch averages over the batch instead of
        # summing; confirm which normalisation the training loss uses.
        for k in range( y_pred.shape[0] ):
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( y_pred[k] ) )
    else:
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) )
    tvw = tv * msqw * msqTerm / mytv
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str(  mytv * tvw ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean()]
    return wts

Automatically compute weighting coefficients for a combined loss function based on intensity (MSE), perceptual similarity (feature), and total variation (TV).

Parameters

mdl : tf.keras.Model A trained or untrained model to evaluate predictions on input x.

feature_extractor : tf.keras.Model A model that extracts intermediate features from the input. Commonly a VGG or ResNet trained on a perceptual task.

x : tf.Tensor Input batch to the model.

y : tf.Tensor Ground truth target for x, typically a batch of 2D or 3D volumes.

feature : float, optional Weighting factor for the feature (perceptual) term in the loss. Default is 2.0.

tv : float, optional Weighting factor for the total variation term in the loss. Default is 0.1.

verbose : bool, optional If True, prints each component of the loss and its scaled value.

Returns

list of float A list of computed weights in the order: [msq_weight, feature_weight, tv_weight]

Notes

The total loss (to be used during training) can then be constructed as:

`L = msq_weight * MSE + feature_weight * perceptual_loss + tv_weight * TV`

This function is typically used to balance loss terms before training.

def auto_weight_loss_seg( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, dice=0.5, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    that includes MSE, perceptual similarity, total variation, and segmentation Dice loss.

    The MSE weight is fixed at 10.0. The other weights are chosen so that, on
    this batch, each weighted term equals its relative factor (`feature`, `tv`,
    `dice`) times the weighted MSE term. MSE, feature and TV terms use the
    intensity channel only; the Dice term uses the segmentation channel.

    Parameters
    ----------
    mdl : tf.keras.Model
        A segmentation + super-resolution model that outputs both image and label predictions.

    feature_extractor : tf.keras.Model
        Feature extractor model used to compute perceptual similarity loss.

    x : tf.Tensor
        Input tensor to the model.

    y : tf.Tensor
        Rank-4 (2-D) or rank-5 (3-D) target tensor with two channels on the
        last axis: [intensity_image, segmentation_label].

    feature : float, optional
        Relative weight of the perceptual feature loss term. Default is 2.0.

    tv : float, optional
        Relative weight of the total variation (TV) term. Default is 0.1.

    dice : float, optional
        Relative weight of the Dice loss term (for segmentation agreement). Default is 0.5.

    verbose : bool, optional
        If True, prints the scaled values of each component loss.

    Returns
    -------
    list of float
        A list of loss term weights in the order:
        `[msq_weight, feature_weight, tv_weight, dice_weight]`.
        The Dice weight is returned as an absolute value.

    Raises
    ------
    ValueError
        If `y` is not rank 4 or rank 5.

    Notes
    -----
    - The input and output tensors must be shaped such that the last axis is 2:
      channel 0 is intensity, channel 1 is segmentation.
    - This is useful for dual-task networks that predict both high-res images
      and associated segmentation masks.

    See Also
    --------
    binary_dice_loss : Computes Dice loss between predicted and ground-truth masks.
    """
    y_pred = mdl( x )
    rank = len( y.shape )
    if rank == 5:
        tdim = 3
    elif rank == 4:
        tdim = 2
    else:
        raise ValueError(
            "y must be a rank-4 (2-D) or rank-5 (3-D) batch; got rank " + str( rank ) )
    # Reduce over every non-batch axis, giving one value per sample.
    myax = list( range( 1, rank ) )
    # Split once along the trailing channel axis: [intensity, segmentation].
    y_intensity, y_seg = tf.split( y, 2, axis=tdim+1 )
    y_intensity_p, y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )
    squared_difference = tf.square( y_intensity - y_intensity_p )
    msqTerm = tf.reduce_mean( squared_difference, axis=myax )
    temp1 = feature_extractor( y_intensity )
    temp2 = feature_extractor( y_intensity_p )
    featureTerm = tf.reduce_mean( tf.square( temp1 - temp2 ), axis=myax )
    msqw = 10.0
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        # NOTE(review): y_pred[k,:,:,:,0] is rank 3, so TF reads it as a single
        # (H, W, C) image. TV is therefore taken along the first two spatial
        # axes, with the third treated as channels. This differs from
        # auto_weight_loss; confirm it matches the training loss.
        for k in range( y_pred.shape[0] ):
            sqzd = y_pred[k,:,:,:,0]
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) )
    else:
        # Keep the channel axis (0:1) so the input is (batch, H, W, 1) and TF
        # returns one TV value per image. Indexing channel 0 away would give a
        # rank-3 tensor that TF treats as one image whose height is the batch
        # axis, mixing neighbouring samples into the TV term.
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0:1] ) )
    tvw = tv * msqw * msqTerm / mytv
    mydice = tf.reduce_mean( binary_dice_loss( y_seg, y_seg_p ) )
    dicew = dice * msqw * msqTerm / mydice
    # The Dice loss may be negative, so take the magnitude of the weight.
    dicewt = np.abs( dicew.numpy().mean() )
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str(  mytv * tvw ) )
        print( "Dice: " + str( mydice * dicewt ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean(), dicewt ]
    return wts

Automatically compute weighting coefficients for a combined loss function that includes MSE, perceptual similarity, total variation, and segmentation Dice loss.

Parameters

mdl : tf.keras.Model A segmentation + super-resolution model that outputs both image and label predictions.

feature_extractor : tf.keras.Model Feature extractor model used to compute perceptual similarity loss.

x : tf.Tensor Input tensor to the model.

y : tf.Tensor Target tensor with two channels: [intensity_image, segmentation_label].

feature : float, optional Relative weight of the perceptual feature loss term. Default is 2.0.

tv : float, optional Relative weight of the total variation (TV) term. Default is 0.1.

dice : float, optional Relative weight of the Dice loss term (for segmentation agreement). Default is 0.5.

verbose : bool, optional If True, prints the scaled values of each component loss.

Returns

list of float A list of loss term weights in the order: [msq_weight, feature_weight, tv_weight, dice_weight]

Notes

  • The input and output tensors must be shaped such that the last axis is 2: channel 0 is intensity, channel 1 is segmentation.
  • This is useful for dual-task networks that predict both high-res images and associated segmentation masks.

See Also

binary_dice_loss : Computes Dice loss between predicted and ground-truth masks.

def numpy_generator( filenames ):
    """
    Placeholder data generator that produces exactly one empty batch.

    The argument is accepted only for signature compatibility with the real
    generators and is never inspected. The generator is exhausted after its
    single item.

    Parameters
    ----------
    filenames : any
        Ignored.

    Yields
    ------
    tuple
        One `(None, None, None)` tuple, standing in for the
        (low-res, high-res, upsampled) batch slots.
    """
    placeholder = (None,) * 3
    yield placeholder

A placeholder or stub for a data generator.

This generator yields a tuple of None values once and then stops. It is likely intended as a template or for debugging purposes where a generator object is required but no actual data needs to be processed.

Parameters

filenames : any An argument that is not used by the function.

Yields

tuple A single tuple (None, None, None).

def image_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image patches for model training.

    This function continuously generates batches of low-resolution (input) and
    high-resolution (ground truth) image patches. It is designed to be fed
    directly into a Keras `model.fit()` call. Every yielded item has the same
    structure, determined by `istest`.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution (ground truth) patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution (input) patches.
    patch_scaler : bool, optional
        If True, scales patch intensities to [0, 1]. Default is True.
    istest : bool, optional
        If True, every yielded item also includes a third element: a baseline
        linearly upsampled version of the low-resolution patch for comparison.
        Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    -------
    tuple
        A tuple of float32 TensorFlow tensors ready for training or evaluation.
        - If `istest` is False: `(low_res_batch, high_res_batch)`
        - If `istest` is True: `(low_res_batch, high_res_batch, baseline_upsampled_batch)`

    See Also
    --------
    image_patch_training_data_from_filenames : The function that performs the
                                               underlying patch extraction.
    """
    while True:
        patchesResam, patchesOrig, patchesUp = image_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest   = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        # Exactly one item per batch. Without the else, test mode would also
        # emit a 2-tuple after each 3-tuple, alternating the yielded structure.
        if istest:
            yield (patchesResam, patchesOrig, patchesUp)
        else:
            yield (patchesResam, patchesOrig)

Creates an infinite generator of paired image patches for model training.

This function continuously generates batches of low-resolution (input) and high-resolution (ground truth) image patches. It is designed to be fed directly into a Keras model.fit() call.

Parameters

filenames : list of str List of file paths to the high-resolution source images. nPatches : int The number of patch pairs to generate and yield in each batch. target_patch_size : tuple or list of int The dimensions of the high-resolution (ground truth) patches. target_patch_size_low : tuple or list of int The dimensions of the low-resolution (input) patches. patch_scaler : bool, optional If True, scales patch intensities to [0, 1]. Default is True. istest : bool, optional If True, the generator will also yield a third item: a baseline linearly upsampled version of the low-resolution patch for comparison. Default is False. verbose : bool, optional If True, passes verbosity to the underlying patch generation function. Default is False.

Yields

tuple A tuple of TensorFlow tensors ready for training or evaluation. - If istest is False: (low_res_batch, high_res_batch) - If istest is True: (low_res_batch, high_res_batch, baseline_upsampled_batch)

See Also

image_patch_training_data_from_filenames : The function that performs the underlying patch extraction.

def seg_generator( filenames, nPatches, target_patch_size, target_patch_size_low, patch_scaler=True, istest=False, verbose=False):
def seg_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image and segmentation patches.

    Each iteration calls ``seg_patch_training_data_from_filenames`` once (with
    ``to_tensorflow=True``) and yields exactly one tuple built from that call.
    The patches are multi-channel: one channel is the intensity image and the
    other is a segmentation mask. It is designed for training multi-task
    super-resolution models.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution patches.
    patch_scaler : bool, optional
        If True, scales the intensity channel of patches to [0, 1]. Default is True.
    istest : bool, optional
        If True, each yielded tuple also contains a baseline upsampled patch
        for comparison. Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    -------
    tuple
        - If ``istest`` is False: ``(low_res_batch, high_res_batch)``
        - If ``istest`` is True: ``(low_res_batch, high_res_batch, baseline_upsampled_batch)``

        Each tensor has two channels: Channel 0 contains the intensity image,
        and Channel 1 contains the segmentation mask.

    See Also
    --------
    seg_patch_training_data_from_filenames : The function that performs the
                                             underlying patch extraction.
    image_generator : A similar generator for intensity-only data.
    """
    while True:
        patchesResam, patchesOrig, patchesUp = seg_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest   = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        # Exactly one yield per sampled batch. Previously the 2-tuple yield
        # also ran after the 3-tuple one, so test generators alternated
        # between 3-tuples and 2-tuples of the same data.
        if istest:
            yield (patchesResam, patchesOrig, patchesUp)
        else:
            yield (patchesResam, patchesOrig)

Creates an infinite generator of paired image and segmentation patches.

This function continuously generates batches of multi-channel patches, where one channel is the intensity image and the other is a segmentation mask. It is designed for training multi-task super-resolution models.

Parameters

filenames : list of str List of file paths to the high-resolution source images. nPatches : int The number of patch pairs to generate and yield in each batch. target_patch_size : tuple or list of int The dimensions of the high-resolution patches. target_patch_size_low : tuple or list of int The dimensions of the low-resolution patches. patch_scaler : bool, optional If True, scales the intensity channel of patches to [0, 1]. Default is True. istest : bool, optional If True, yields an additional baseline upsampled patch for comparison. Default is False. verbose : bool, optional If True, passes verbosity to the underlying patch generation function. Default is False.

Yields

tuple A tuple of multi-channel TensorFlow tensors. Each tensor has two channels: Channel 0 contains the intensity image, and Channel 1 contains the segmentation mask.

See Also

seg_patch_training_data_from_filenames : The function that performs the underlying patch extraction. image_generator : A similar generator for intensity-only data.

def train( mdl, filenames_train, filenames_test, target_patch_size, target_patch_size_low, output_prefix, n_test=8, learning_rate=5e-05, feature_layer=6, feature=2, tv=0.1, max_iterations=1000, batch_size=1, save_all_best=False, feature_type='grader', check_eval_data_iteration=20, verbose=False):
def train(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size,
    target_patch_size_low,
    output_prefix,
    n_test = 8,
    learning_rate=5e-5,
    feature_layer = 6,
    feature = 2,
    tv = 0.1,
    max_iterations = 1000,
    batch_size = 1,
    save_all_best = False,
    feature_type = 'grader',
    check_eval_data_iteration = 20,
    verbose = False  ):
    """
    Orchestrates the training process for a super-resolution model.

    This function handles the entire training loop, including setting up data
    generators, defining a composite loss function (MSE + perceptual + total
    variation), automatically balancing loss weights, iteratively training the
    model, periodically evaluating performance, and saving the best-performing
    model.

    Loss weights are read from ``<output_prefix>_training_weights.csv`` if that
    file exists; otherwise they are computed with ``auto_weight_loss`` and
    written to that file.

    Parameters
    ----------
    mdl : tf.keras.Model
        The Keras model to be trained.
    filenames_train : list of str
        List of file paths for the training dataset.
    filenames_test : list of str
        List of file paths for the validation/testing dataset.
    target_patch_size : tuple or list
        The dimensions of the high-resolution target patches.
    target_patch_size_low : tuple or list
        The dimensions of the low-resolution input patches.
    output_prefix : str
        A prefix for all output files (e.g., model weights, training logs).
    n_test : int, optional
        The number of validation patches to use for evaluation. Default is 8.
    learning_rate : float, optional
        The learning rate for the Adam optimizer. Default is 5e-5.
    feature_layer : int, optional
        The layer index from the feature extractor to use for perceptual loss.
        Default is 6.
    feature : float, optional
        The relative weight of the perceptual (feature) loss term. Default is 2.
    tv : float, optional
        The relative weight of the Total Variation (TV) regularization term.
        Default is 0.1.
    max_iterations : int, optional
        The total number of training iterations to run. Default is 1000.
    batch_size : int, optional
        Number of samples the loss iterates over for the 3D TV term. Note: this
        implementation is optimized for batch_size=1 and may need adjustment
        for larger batches. Default is 1.
    save_all_best : bool, optional
        If True, saves a new model file (named with the iteration) every time
        validation loss improves. If False, overwrites the single best model
        file. Default is False.
    feature_type : str, optional
        The type of feature extractor for perceptual loss. Options: 'grader',
        'vgg', 'vggrandom'. Default is 'grader'.
    check_eval_data_iteration : int, optional
        The frequency (in iterations) at which to run validation and save logs.
        Default is 20.
    verbose : bool, optional
        If True, prints detailed progress information. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the training history, with columns
        'train_loss', 'test_loss', 'best', 'eval_psnr', 'eval_psnr_lin'
        (one row per iteration; evaluation columns are filled only on
        evaluation iterations).

    Raises
    ------
    ValueError
        If ``feature_type`` is not supported, if ``check_eval_data_iteration``
        is 0, or if the generated patches are neither 4-D nor 5-D tensors.
    """
    if check_eval_data_iteration == 0:
        # The evaluation schedule uses ``myrs % check_eval_data_iteration``;
        # fail fast instead of raising ZeroDivisionError after the first fit.
        raise ValueError("check_eval_data_iteration must be non-zero")
    colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin']
    training_path = np.zeros( [ max_iterations, len(colnames) ] )
    training_weights = np.zeros( [1,3] )
    if verbose:
        print("begin get feature extractor " + feature_type)
    if feature_type == 'grader':
        feature_extractor = get_grader_feature_network( feature_layer )
    elif feature_type == 'vggrandom':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False )
    elif feature_type == 'vgg':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer )
    else:
        # ValueError subclasses Exception, so existing handlers still catch it.
        raise ValueError("feature type does not exist")
    if verbose:
        print("begin train generator")
    mydatgen = image_generator(
        filenames_train,
        nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=False , verbose=False)
    if verbose:
        print("begin test generator")
    mydatgenTest = image_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    # One validation sample: used as Keras validation_data and for auto-weighting.
    patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest )
    ndim_test = len( patchesOrigTeTf.shape )
    if ndim_test == 5:
        tdim = 3
    elif ndim_test == 4:
        tdim = 2
    else:
        raise ValueError( "expected 4-D (2D) or 5-D (3D) patches, got rank " + str( ndim_test ) )
    if verbose:
        print("begin train generator #2 at dim: " + str( tdim))
    # Evaluation batch of n_test samples. Each sample comes from a freshly
    # created generator that is advanced exactly once.
    mydatgenTest = image_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest )
    for k in range( n_test - 1 ):
        mydatgenTest = image_generator( filenames_test, nPatches=1,
            target_patch_size=target_patch_size,
            target_patch_size_low=target_patch_size_low,
            istest=True, verbose=True)
        temp0, temp1, temp2 = next( mydatgenTest )
        patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0)
        patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0)
        patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0)
    if verbose:
        print("begin auto_weight_loss")
    wts_csv = output_prefix + "_training_weights.csv"
    if exists( wts_csv ):
        # Reuse weights from a previous run with the same output_prefix.
        wtsdf = pd.read_csv( wts_csv )
        wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0]]
        if verbose:
            print( "preset weights:" )
    else:
        with eager_mode():
            wts = auto_weight_loss( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf,
                feature=feature, tv=tv )
        for k in range(len(wts)):
            training_weights[0,k]=wts[k]
        pd.DataFrame(training_weights, columns = ["msq","feat","tv"] ).to_csv( wts_csv )
        if verbose:
            print( "automatic weights:" )
    if verbose:
        print( wts )
    def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], mybs = batch_size ):
        """Composite loss: MSE + Perceptual Loss + Total Variation."""
        squared_difference = tf.square(y_true - y_pred)
        if len( y_true.shape ) == 5:
            tdim = 3
            myax = [1,2,3,4]
        elif len( y_true.shape ) == 4:
            tdim = 2
            myax = [1,2,3]
        else:
            raise ValueError( "expected 4-D (2D) or 5-D (3D) tensors, got rank " + str( len( y_true.shape ) ) )
        msqTerm = tf.reduce_mean(squared_difference, axis=myax)
        temp1 = feature_extractor(y_true)
        temp2 = feature_extractor(y_pred)
        feature_difference = tf.square(temp1-temp2)
        featureTerm = tf.reduce_mean(feature_difference, axis=myax)
        loss = msqTerm * msqwt + featureTerm * fw
        mytv = tf.cast( 0.0, 'float32')
        # mybs =  int( y_pred.shape[0] ) --- should work but ... ?
        if tdim == 3:
            # NOTE(review): only the first `mybs` samples contribute to the TV
            # term, so batches larger than batch_size (e.g. in evaluate()) are
            # only partially regularized — confirm whether this is intended.
            for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version
                sqzd = y_pred[k,:,:,:,:]
                mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt
        if tdim == 2:
            mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) ) * tvwt
        return( loss + mytv )
    if verbose:
        print("begin model compilation")
    opt = tf.keras.optimizers.Adam( learning_rate=learning_rate )
    mdl.compile(optimizer=opt, loss=my_loss_6)
    bestValLoss=1e12
    if verbose:
        print( "begin training", flush=True  )
    for myrs in range( max_iterations ):
        tracker = mdl.fit( mydatgen,  epochs=2, steps_per_epoch=4, verbose=1,
            validation_data=(patchesResamTeTf,patchesOrigTeTf) )
        # NOTE(review): fit runs 2 epochs but only the first epoch's losses are recorded.
        training_path[myrs,0]=tracker.history['loss'][0]
        training_path[myrs,1]=tracker.history['val_loss'][0]
        training_path[myrs,2]=0
        print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True  )
        if myrs % check_eval_data_iteration == 0:
            with tf.device("/cpu:0"):
                myofn = output_prefix + "_best_mdl.keras"
                if save_all_best:
                    myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras"
                tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB )
                if ( tester < bestValLoss ):
                    print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True )
                    bestValLoss = tester
                    tf.keras.models.save_model( mdl, myofn )
                    training_path[myrs,2]=1
                pp = mdl.predict( patchesResamTeTfB, batch_size = 1 )
                # PSNR on intensities rescaled by 220 with max_val 255.
                myssimSR = tf.image.psnr( pp * 220, patchesOrigTeTfB* 220, max_val=255 )
                myssimSR = tf.reduce_mean( myssimSR ).numpy()
                myssimBI = tf.image.psnr( patchesUpTeTfB * 220, patchesOrigTeTfB* 220, max_val=255 )
                myssimBI = tf.reduce_mean( myssimBI ).numpy()
                print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ), flush=True  )
                training_path[myrs,3]=myssimSR # psnr
                training_path[myrs,4]=myssimBI # psnrlin
                pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" )
    training_path = pd.DataFrame(training_path, columns = colnames )
    return training_path

Orchestrates the training process for a super-resolution model.

This function handles the entire training loop, including setting up data generators, defining a composite loss function, automatically balancing loss weights, iteratively training the model, periodically evaluating performance, and saving the best-performing model weights.

Parameters

mdl : tf.keras.Model The Keras model to be trained. filenames_train : list of str List of file paths for the training dataset. filenames_test : list of str List of file paths for the validation/testing dataset. target_patch_size : tuple or list The dimensions of the high-resolution target patches. target_patch_size_low : tuple or list The dimensions of the low-resolution input patches. output_prefix : str A prefix for all output files (e.g., model weights, training logs). n_test : int, optional The number of validation patches to use for evaluation. Default is 8. learning_rate : float, optional The learning rate for the Adam optimizer. Default is 5e-5. feature_layer : int, optional The layer index from the feature extractor to use for perceptual loss. Default is 6. feature : float, optional The relative weight of the perceptual (feature) loss term. Default is 2.0. tv : float, optional The relative weight of the Total Variation (TV) regularization term. Default is 0.1. max_iterations : int, optional The total number of training iterations to run. Default is 1000. batch_size : int, optional The batch size for training. Note: this implementation is optimized for batch_size=1 and may need adjustment for larger batches. Default is 1. save_all_best : bool, optional If True, saves a new model file every time validation loss improves. If False, overwrites the single best model file. Default is False. feature_type : str, optional The type of feature extractor for perceptual loss. Options: 'grader', 'vgg', 'vggrandom'. Default is 'grader'. check_eval_data_iteration : int, optional The frequency (in iterations) at which to run validation and save logs. Default is 20. verbose : bool, optional If True, prints detailed progress information. Default is False.

Returns

pd.DataFrame A DataFrame containing the training history, with columns for training loss, validation loss, PSNR, and baseline PSNR over iterations.

def binary_dice_loss(y_true, y_pred):
def binary_dice_loss(y_true, y_pred, smoothing_factor=1e-4):
    """
    Negative soft Dice coefficient, usable as a binary segmentation loss.

    Both inputs are flattened over every axis, including the batch axis, so a
    single Dice value is computed for the whole batch:

        -(2 * sum(y_true * y_pred) + s) / (sum(y_true) + sum(y_pred) + s)

    where ``s`` is ``smoothing_factor``. Minimizing this value maximizes
    overlap. Note that this is ``-Dice``, not ``1 - Dice``. The two differ
    only by a constant, so their gradients are identical.

    Parameters
    ----------
    y_true : tf.Tensor
        The ground truth binary segmentation mask. Values should be 0 or 1.
    y_pred : tf.Tensor
        The predicted segmentation mask, typically with values in [0, 1]
        from a sigmoid activation.
    smoothing_factor : float, optional
        Added to numerator and denominator so the result is defined when both
        masks are empty (the loss is then exactly -1). Default is 1e-4.

    Returns
    -------
    tf.Tensor
        A scalar tensor. For inputs in [0, 1] it lies in [-1, 0]: about -1
        for a perfect match and about 0 for no overlap.
    """
    # tf.reshape / tf.reduce_sum are the ops behind the former
    # tf.keras.backend.flatten / sum helpers, which Keras 3 no longer provides.
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    # This is -1 * Dice Similarity Coefficient
    return -1 * (2 * intersection + smoothing_factor) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smoothing_factor)

Computes the Dice loss for binary segmentation tasks.

The Dice coefficient is a common metric for comparing the overlap of two samples. This loss function returns the negative Dice coefficient (-Dice), making it suitable for minimization during training. A smoothing factor is added to avoid division by zero when both the prediction and the ground truth are empty.

Parameters

y_true : tf.Tensor The ground truth binary segmentation mask. Values should be 0 or 1. y_pred : tf.Tensor The predicted binary segmentation mask, typically with values in [0, 1] from a sigmoid activation.

Returns

tf.Tensor A scalar tensor representing the Dice loss. The value ranges from -1 (perfect match) to 0 (no overlap). This function uses -dice_coeff rather than the common 1 - dice_coeff form; the two differ only by a constant.

def train_seg( mdl, filenames_train, filenames_test, target_patch_size, target_patch_size_low, output_prefix, n_test=8, learning_rate=5e-05, feature_layer=6, feature=2, tv=0.1, dice=0.5, max_iterations=1000, batch_size=1, save_all_best=False, feature_type='grader', check_eval_data_iteration=20, verbose=False):
def train_seg(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size,
    target_patch_size_low,
    output_prefix,
    n_test = 8,
    learning_rate=5e-5,
    feature_layer = 6,
    feature = 2,
    tv = 0.1,
    dice = 0.5,
    max_iterations = 1000,
    batch_size = 1,
    save_all_best = False,
    feature_type = 'grader',
    check_eval_data_iteration = 20,
    verbose = False  ):
    """
    Orchestrates training for a multi-task image and segmentation SR model.

    This function extends the `train` function to handle models that predict
    both a super-resolved image (channel 0) and a super-resolved segmentation
    mask (channel 1). It uses a four-component composite loss: MSE (for image),
    a perceptual loss (for image), Total Variation (for image), and Dice loss
    (for segmentation).

    Loss weights are read from ``<output_prefix>_training_weights.csv`` if that
    file exists; otherwise they are computed with ``auto_weight_loss_seg`` and
    written to that file.

    Parameters
    ----------
    mdl : tf.keras.Model
        The 2-channel Keras model to be trained.
    filenames_train : list of str
        List of file paths for the training dataset.
    filenames_test : list of str
        List of file paths for the validation/testing dataset.
    target_patch_size : tuple or list
        The dimensions of the high-resolution target patches.
    target_patch_size_low : tuple or list
        The dimensions of the low-resolution input patches.
    output_prefix : str
        A prefix for all output files.
    n_test : int, optional
        Number of validation patches for evaluation. Default is 8.
    learning_rate : float, optional
        Learning rate for the Adam optimizer. Default is 5e-5.
    feature_layer : int, optional
        Layer from the feature extractor for perceptual loss. Default is 6.
    feature : float, optional
        Relative weight of the perceptual loss term. Default is 2.
    tv : float, optional
        Relative weight of the Total Variation regularization term. Default is 0.1.
    dice : float, optional
        Relative weight of the Dice loss term for the segmentation mask.
        Default is 0.5.
    max_iterations : int, optional
        Total number of training iterations. Default is 1000.
    batch_size : int, optional
        Number of samples the loss iterates over for the 3D TV term. Default is 1.
    save_all_best : bool, optional
        If True, saves all models that improve validation loss. Default is False.
    feature_type : str, optional
        Type of feature extractor for perceptual loss: 'grader' or
        'vggrandom'. Any other value (e.g. 'vgg') selects the unbiased
        pseudo-3D VGG extractor. Default is 'grader'.
    check_eval_data_iteration : int, optional
        Frequency (in iterations) for running validation. Default is 20.
    verbose : bool, optional
        If True, prints detailed progress information. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the training history, with columns
        'train_loss', 'test_loss', 'best', 'eval_psnr', 'eval_psnr_lin',
        'eval_msq', 'eval_dice'. The 'eval_dice' value is the (negative)
        output of ``binary_dice_loss``.

    Raises
    ------
    ValueError
        If ``check_eval_data_iteration`` is 0, or if the generated patches are
        neither 4-D nor 5-D tensors.

    See Also
    --------
    train : The training function for single-task (intensity-only) models.
    """
    if check_eval_data_iteration == 0:
        # The evaluation schedule uses ``myrs % check_eval_data_iteration``;
        # fail fast instead of raising ZeroDivisionError after the first fit.
        raise ValueError("check_eval_data_iteration must be non-zero")
    colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin','eval_msq','eval_dice']
    training_path = np.zeros( [ max_iterations, len(colnames) ] )
    training_weights = np.zeros( [1,4] )
    if verbose:
        print("begin get feature extractor")
    # NOTE(review): train() wraps the VGG constructors and auto-weighting in
    # eager_mode(); this function does not — confirm whether that is intended.
    if feature_type == 'grader':
        feature_extractor = get_grader_feature_network( feature_layer )
    elif feature_type == 'vggrandom':
        feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False )
    else:
        feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer  )
    if verbose:
        print("begin train generator")
    mydatgen = seg_generator(
        filenames_train,
        nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=False , verbose=False)
    if verbose:
        print("begin test generator")
    mydatgenTest = seg_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    # One validation sample: used as Keras validation_data and for auto-weighting.
    patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest )
    ndim_test = len( patchesOrigTeTf.shape )
    if ndim_test == 5:
        tdim = 3
    elif ndim_test == 4:
        tdim = 2
    else:
        raise ValueError( "expected 4-D (2D) or 5-D (3D) patches, got rank " + str( ndim_test ) )
    if verbose:
        print("begin train generator #2 at dim: " + str( tdim))
    # Evaluation batch of n_test samples. Each sample comes from a freshly
    # created generator that is advanced exactly once.
    mydatgenTest = seg_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest )
    for k in range( n_test - 1 ):
        mydatgenTest = seg_generator( filenames_test, nPatches=1,
            target_patch_size=target_patch_size,
            target_patch_size_low=target_patch_size_low,
            istest=True, verbose=True)
        temp0, temp1, temp2 = next( mydatgenTest )
        patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0)
        patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0)
        patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0)
    if verbose:
        print("begin auto_weight_loss_seg")
    wts_csv = output_prefix + "_training_weights.csv"
    if exists( wts_csv ):
        # Reuse weights from a previous run with the same output_prefix.
        wtsdf = pd.read_csv( wts_csv )
        wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0], wtsdf['dice'][0]]
        if verbose:
            print( "preset weights:" )
    else:
        wts = auto_weight_loss_seg( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf,
            feature=feature, tv=tv, dice=dice )
        for k in range(len(wts)):
            training_weights[0,k]=wts[k]
        pd.DataFrame(training_weights, columns = ["msq","feat","tv","dice"] ).to_csv( wts_csv )
        if verbose:
            print( "automatic weights:" )
    if verbose:
        print( wts )
    def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], dicewt=wts[3], mybs = batch_size ):
        """Composite loss: MSE + Perceptual + TV + Dice."""
        if len( y_true.shape ) == 5:
            tdim = 3
            myax = [1,2,3,4]
        elif len( y_true.shape ) == 4:
            tdim = 2
            myax = [1,2,3]
        else:
            raise ValueError( "expected 4-D (2D) or 5-D (3D) tensors, got rank " + str( len( y_true.shape ) ) )
        # Channel 0 = intensity, channel 1 = segmentation (split once per tensor).
        y_intensity, y_seg = tf.split( y_true, 2, axis=tdim+1 )
        y_intensity_p, y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )
        squared_difference = tf.square(y_intensity - y_intensity_p)
        msqTerm = tf.reduce_mean(squared_difference, axis=myax)
        temp1 = feature_extractor(y_intensity)
        temp2 = feature_extractor(y_intensity_p)
        feature_difference = tf.square(temp1-temp2)
        featureTerm = tf.reduce_mean(feature_difference, axis=myax)
        loss = msqTerm * msqwt + featureTerm * fw
        mytv = tf.cast( 0.0, 'float32')
        # NOTE(review): y_pred[k,:,:,:,0] (3D) and y_pred[:,:,:,0] (2D) drop the
        # channel axis, so tf.image.total_variation treats the slice as a single
        # (height, width, channels) image. In 2D the batch axis becomes the
        # height axis. train() keeps the channel axis instead. This is left
        # unchanged because the loss weights may depend on this exact form —
        # confirm before changing.
        if tdim == 3:
            for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version
                sqzd = y_pred[k,:,:,:,0]
                mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt
        if tdim == 2:
            mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0] ) ) * tvwt
        dicer = tf.reduce_mean( dicewt * binary_dice_loss( y_seg, y_seg_p ) )
        return( loss + mytv + dicer )
    if verbose:
        print("begin model compilation")
    opt = tf.keras.optimizers.Adam( learning_rate=learning_rate )
    mdl.compile(optimizer=opt, loss=my_loss_6)
    bestValLoss=1e12
    if verbose:
        print( "begin training", flush=True  )
    for myrs in range( max_iterations ):
        tracker = mdl.fit( mydatgen,  epochs=2, steps_per_epoch=4, verbose=1,
            validation_data=(patchesResamTeTf,patchesOrigTeTf) )
        # NOTE(review): fit runs 2 epochs but only the first epoch's losses are recorded.
        training_path[myrs,0]=tracker.history['loss'][0]
        training_path[myrs,1]=tracker.history['val_loss'][0]
        training_path[myrs,2]=0
        print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True  )
        if myrs % check_eval_data_iteration == 0:
            with tf.device("/cpu:0"):
                myofn = output_prefix + "_best_mdl.keras"
                if save_all_best:
                    myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras"
                tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB )
                if ( tester < bestValLoss ):
                    print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True )
                    bestValLoss = tester
                    tf.keras.models.save_model( mdl, myofn )
                    training_path[myrs,2]=1
                pp = mdl.predict( patchesResamTeTfB, batch_size = 1 )
                # Index 0 = intensity channel, index 1 = segmentation channel.
                pp = tf.split( pp, 2, axis=tdim+1 )
                y_orig = tf.split( patchesOrigTeTfB, 2, axis=tdim+1 )
                y_up = tf.split( patchesUpTeTfB, 2, axis=tdim+1 )
                # PSNR on intensities rescaled by 220 with max_val 255.
                myssimSR = tf.image.psnr( pp[0] * 220, y_orig[0]* 220, max_val=255 )
                myssimSR = tf.reduce_mean( myssimSR ).numpy()
                myssimBI = tf.image.psnr( y_up[0] * 220, y_orig[0]* 220, max_val=255 )
                myssimBI = tf.reduce_mean( myssimBI ).numpy()
                squared_difference = tf.square(y_orig[0] - pp[0])
                msqTerm = tf.reduce_mean(squared_difference).numpy()
                dicer = binary_dice_loss( y_orig[1], pp[1] )
                dicer = tf.reduce_mean( dicer ).numpy()
                print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ) + " MSQ: " + str(msqTerm) + " DICE: " + str(dicer), flush=True  )
                training_path[myrs,3]=myssimSR # psnr
                training_path[myrs,4]=myssimBI # psnrlin
                training_path[myrs,5]=msqTerm # msq
                training_path[myrs,6]=dicer # dice
                pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" )
    training_path = pd.DataFrame(training_path, columns = colnames )
    return training_path

Orchestrates training for a multi-task image and segmentation SR model.

This function extends the train function to handle models that predict both a super-resolved image and a super-resolved segmentation mask. It uses a four-component composite loss: MSE (for image), a perceptual loss (for image), Total Variation (for image), and Dice loss (for segmentation).

Parameters

mdl : tf.keras.Model The 2-channel Keras model to be trained. filenames_train : list of str List of file paths for the training dataset. filenames_test : list of str List of file paths for the validation/testing dataset. target_patch_size : tuple or list The dimensions of the high-resolution target patches. target_patch_size_low : tuple or list The dimensions of the low-resolution input patches. output_prefix : str A prefix for all output files. n_test : int, optional Number of validation patches for evaluation. Default is 8. learning_rate : float, optional Learning rate for the Adam optimizer. Default is 5e-5. feature_layer : int, optional Layer from the feature extractor for perceptual loss. Default is 6. feature : float, optional Relative weight of the perceptual loss term. Default is 2.0. tv : float, optional Relative weight of the Total Variation regularization term. Default is 0.1. dice : float, optional Relative weight of the Dice loss term for the segmentation mask. Default is 0.5. max_iterations : int, optional Total number of training iterations. Default is 1000. batch_size : int, optional The batch size for training. Default is 1. save_all_best : bool, optional If True, saves all models that improve validation loss. Default is False. feature_type : str, optional Type of feature extractor for perceptual loss. Default is 'grader'. check_eval_data_iteration : int, optional Frequency (in iterations) for running validation. Default is 20. verbose : bool, optional If True, prints detailed progress information. Default is False.

Returns

pd.DataFrame A DataFrame containing the training history, including columns for losses and evaluation metrics like PSNR and Dice score.

See Also

train : The training function for single-task (intensity-only) models.

def read_srmodel(srfilename, custom_objects=None):
def read_srmodel(srfilename, custom_objects=None):
    """
    Load a super-resolution model (h5, .keras, or SavedModel format),
    and determine its upsampling factor.

    The upsampling factor is measured empirically: an all-zero dummy input
    with a spatial extent of 8 voxels per axis is passed through the model,
    and the output's spatial extent on each axis is divided by 8.

    Parameters
    ----------
    srfilename : str
        Path to the model file (.h5, .keras, or a SavedModel folder).
        A leading ``~`` is expanded.
    custom_objects : dict, optional
        Dictionary of custom objects used in the model (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)})

    Returns
    -------
    model : tf.keras.Model
        The loaded (uncompiled) model.
    upsampling_factor : list of int
        List describing the upsampling factor:
        - For 3D input: [x_up, y_up, z_up, channels]
        - For 2D input: [x_up, y_up, channels]
        where ``channels`` is the number of input channels of the model.

    Raises
    ------
    ValueError
        If ``srfilename`` is neither a directory nor a .h5/.keras file, or if
        the model input is not rank 4 (2D) or rank 5 (3D).
    RuntimeError
        If the dummy forward pass used to infer the upsampling factor fails
        (the original error is chained as ``__cause__``).

    Example
    -------
    >>> mdl, up = read_srmodel("mymodel.keras")
    >>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)})
    """
    srfilename = os.path.expanduser(srfilename)
    ext = os.path.splitext(srfilename)[1].lower()

    # SavedModel folders and single-file .h5/.keras models share one loader;
    # only the format check differs.
    if not (os.path.isdir(srfilename) or ext in ('.h5', '.keras')):
        raise ValueError(f"Unsupported model format: {ext}")
    model = tf.keras.models.load_model(
        srfilename, custom_objects=custom_objects, compile=False)

    # Multi-input models report a list of shapes; the first one is the image.
    input_shape = model.input_shape
    if isinstance(input_shape, list):
        input_shape = input_shape[0]
    if len(input_shape) not in (4, 5):
        raise ValueError(
            f"Expected a 2D (rank-4) or 3D (rank-5) model input, got {input_shape}")
    n_spatial = len(input_shape) - 2   # drop batch and channel axes
    nchan = int(input_shape[-1])       # channels-last layout
    probe = 8                          # spatial extent of the dummy input

    dummy_input = np.zeros([1] + [probe] * n_spatial + [nchan])
    try:
        try:
            output = model(dummy_input)
        except Exception:
            # Some models only accept inputs keyed by input-layer name.
            output = model({model.input_names[0]: dummy_input})
        # Multi-output models return a sequence; the first output is the image.
        if isinstance(output, (list, tuple)):
            output = output[0]
        outshp = output.shape
        upsampling = [int(outshp[axis] / probe) for axis in range(1, n_spatial + 1)]
    except Exception as e:
        raise RuntimeError(f"Could not infer upsampling factor. Error: {e}") from e

    return model, upsampling + [nchan]

Load a super-resolution model (h5, .keras, or SavedModel format), and determine its upsampling factor.

Parameters

srfilename : str — Path to the model file (.h5, .keras, or a SavedModel folder).
custom_objects : dict, optional — Dictionary of custom objects used in the model (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)}).

Returns

model : tf.keras.Model — The loaded model.
upsampling_factor : list of int — List describing the upsampling factor:
  • For 3D input: [x_up, y_up, z_up, channels]
  • For 2D input: [x_up, y_up, channels]

Example

>>> mdl, up = read_srmodel("mymodel.keras")
>>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)})
def simulate_image(shaper=[32, 32, 32], n_levels=10, multiply=False):
def simulate_image( shaper=[32,32,32], n_levels=10, multiply=False ):
    """
    Generate a synthetic multi-level image of the given shape.

    The result is a sum of ``n_levels`` binary layers. Each layer is
    standard-normal noise of shape ``shaper``, smoothed with
    ``ants.smooth_image`` (sigma = ``n_levels``) and then binarized with
    ``ants.threshold_image(..., "Otsu", 1)``.

    Arguments
    ---------
    shaper : [x,y,z] or [x,y]
        Shape of the generated image.

    n_levels : int
        Number of layers summed together; also used as the smoothing sigma.

    multiply : boolean
        If True, layer ``k`` (counting from 0) is scaled by ``k`` before it
        is added, so the first layer contributes nothing.

    Returns
    -------

    ants.image
    """
    # An all-zero image with the requested geometry.  The random draw that is
    # multiplied away is kept so the RNG stream is consumed exactly as before.
    total = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) * 0
    for level in range( n_levels ):
        noise = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) )
        layer = ants.threshold_image( ants.smooth_image( noise, n_levels ), "Otsu", 1 )
        total = total + ( layer * level if multiply else layer )
    return total

generate an image of given shape and number of levels

Arguments

shaper : [x,y,z] or [x,y]

n_levels : int

multiply : boolean

Returns

ants.image

def optimize_upsampling_shape(spacing, modality='T1', roundit=False, verbose=False):
def optimize_upsampling_shape( spacing, modality='T1', roundit=False, verbose=False ):
    """
    Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing
    and imaging modality. This output is used to select an appropriate pretrained
    super-resolution model filename.

    Parameters
    ----------
    spacing : sequence of float
        Voxel spacing (physical size per voxel in mm) from the input image.
        Typically obtained from `ants.get_spacing(image)`. 2D and 3D spacings
        are both supported; the result has one factor per axis.

    modality : str, optional
        Imaging modality. Affects resolution thresholds:
        - 'T1' : anatomical MRI (default minimum spacing: 0.35 mm)
        - 'DTI' : diffusion MRI (default minimum spacing: 1.0 mm)
        - 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm)
        Any other value is treated as 'T1'.

    roundit : bool, optional
        If True, uses rounded integer ratios for the upsampling shape
        (clamped to at least 1). Otherwise, uses floor division with
        constraints.

    verbose : bool, optional
        If True, prints detailed internal values and logic.

    Returns
    -------
    str
        Optimal upsampling shape string in the form 'AxBxC' (or 'AxB' for
        2D spacing), e.g., '2x2x2', '4x4x2'.

    Notes
    -----
    - If every factor would be 1, the function returns 2 on every axis
      (e.g. '2x2x2' for 3D, '2x2' for 2D) instead.
    - Without ``roundit``, a factor of 5 is mapped to 4 and factors above 6
      are capped at 6.
    - The returned string is commonly used to populate a model filename template:

      Example:
          >>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1')
          >>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras')
    """
    # Smallest spacing we allow an axis to be upsampled to, per modality.
    min_spacing_by_modality = {"NM": 0.25, "DTI": 1.0}
    minallowed = min_spacing_by_modality.get(modality, 0.35)  # default: T1

    spacing = list(spacing)
    minspc = min(spacing)
    maxspc = max(spacing)
    ratio = maxspc / minspc
    if ratio == 1.0:
        # Isotropic input; np.round(0.5) == 0.0, which forces every axis
        # through the minimum-spacing branch below.
        ratio = 0.5
    roundratio = np.round(ratio)

    if verbose:
        print("Using minspacing: " + str(minallowed))

    tarshaperaw = []
    tarshape = []
    tarshaperound = []
    for spc in spacing:
        locrat = spc / minspc
        tarshaperaw.append(locrat)
        # If upsampling this axis by the rounded anisotropy ratio would go
        # below the modality's minimum spacing, derive the factor from that
        # minimum instead.
        if spc * roundratio < minallowed:
            locrat = spc / minallowed
        myint = min(max(int(locrat), 1), 6)  # floor, clamped to [1, 6]
        if myint == 5:
            myint = 4  # avoid the uncommon factor 5
        tarshape.append(str(myint))
        # An upsampling factor of 0 is meaningless; clamp rounded values too.
        tarshaperound.append(str(max(int(np.round(locrat)), 1)))

    if verbose:
        print("before emendation:")
        print(tarshaperaw)
        print(tarshaperound)
        print(tarshape)

    if roundit:
        tarshape = tarshaperound
    if all(factor == "1" for factor in tarshape):
        # Default: upsample by 2 on every axis, matching the input's rank.
        tarshape = ["2"] * len(tarshape)
    return "x".join(tarshape)

Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing and imaging modality. This output is used to select an appropriate pretrained super-resolution model filename.

Parameters

spacing : sequence of float Voxel spacing (physical size per voxel in mm) from the input image. Typically obtained from ants.get_spacing(image).

modality : str, optional — Imaging modality. Affects resolution thresholds:
  • 'T1' : anatomical MRI (default minimum spacing: 0.35 mm)
  • 'DTI' : diffusion MRI (minimum spacing: 1.0 mm)
  • 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm)
Any other value is treated as 'T1'.

roundit : bool, optional If True, uses rounded integer ratios for the upsampling shape. Otherwise, uses floor division with constraints.

verbose : bool, optional If True, prints detailed internal values and logic.

Returns

str Optimal upsampling shape string in the form 'AxBxC', e.g., '2x2x2', '4x4x2'.

Notes

  • The function prevents upsampling ratios that would result in '1x1x1' by defaulting to '2x2x2'.
  • When roundit is False, it avoids uncommon ratios by mapping a factor of 5 to 4 and capping factors above 6 at 6.
  • The returned string is commonly used to populate a model filename template:

    Example:

    >>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1')
    >>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras')

def compare_models( model_filenames, img, n_classes=3, poly_order='hist', identifier=None, noise_sd=0.1, verbose=False):
def compare_models( model_filenames, img, n_classes=3,
    poly_order='hist',
    identifier=None, noise_sd=0.1,verbose=False ):
    """
    Evaluates and compares the performance of multiple super-resolution models on a given image.

    This function provides a standardized way to benchmark SR models. For each model,
    it performs the following steps:
    1. Loads the model and determines its upsampling factor.
    2. Downsamples the high-resolution input image (`img`) to create a low-resolution
       input, simulating a real-world scenario.
    3. Adds Gaussian noise to the low-resolution input to test for robustness.
    4. Runs inference using the model to generate a super-resolved output.
    5. Generates a baseline output by upsampling the low-res input with linear interpolation.
    6. Calculates PSNR and SSIM metrics comparing both the model's output and the
       baseline against the original high-resolution image.
    7. If a dual-channel (image + segmentation) model is detected, it also calculates
       Dice scores for segmentation performance.
    8. Aggregates all results into a pandas DataFrame for easy comparison.

    Parameters
    ----------
    model_filenames : list of str
        A list of file paths to the Keras models (.h5, .keras) to be compared.
    img : ants.ANTsImage
        The high-resolution ground truth image. This image will be downsampled to
        create the input for the models. 2D and 3D images are supported.
    n_classes : int, optional
        The number of classes for Otsu's thresholding when auto-generating a
        segmentation for evaluating dual-channel models. Default is 3.
    poly_order : str or int, optional
        Method for intensity matching between the SR output and the reference.
        Options: 'hist' for histogram matching (default), an integer for
        polynomial regression, or None to disable.
    identifier : str, optional
        A custom identifier stored in the ``identifier`` column of every row.
        If None, each row's identifier is inferred from that row's own model
        filename. Default is None.
    noise_sd : float, optional
        Standard deviation of the additive Gaussian noise applied to the
        downsampled image before inference. Default is 0.1.
    verbose : bool, optional
        If True, prints detailed progress and intermediate values. Default is False.

    Returns
    -------
    pd.DataFrame
        One row per model (index 0..n-1). Columns: identifier, imgshape, mdl,
        mdlshape, PSNR.LIN, PSNR.SR, SSIM.LIN, SSIM.SR, DICE.NN, DICE.SR,
        dimwarning. The DICE columns are NaN for single-channel models.

    Notes
    -----
    When evaluating a 2-channel (segmentation) model, the primary metric for the
    segmentation task is the Dice score (`DICE.SR`). The intensity metrics (PSNR, SSIM)
    are still computed on the first channel.
    """
    rows = []
    for model_filename in model_filenames:
        srmdl, upshape = read_srmodel( model_filename )
        if verbose:
            print( model_filename )
            print( upshape )

        # Low-resolution target spacing: input spacing times the model's
        # per-axis upsampling factor.
        inspc = ants.get_spacing(img)
        tarshape = [ float(upshape[j]) * inspc[j] for j in range(len(img.shape)) ]
        # interp_type=0 -> linear interpolation
        dimg = ants.resample_image( img, tarshape, use_voxels=False, interp_type=0 )
        dimg = ants.add_noise_to_image( dimg, 'additivegaussian', [0, noise_sd] )

        dicesr = math.nan
        dicenn = math.nan
        # read_srmodel puts the input channel count last: [x, y, z, c] for 3D,
        # [x, y, c] for 2D, so index from the end to support both.
        if upshape[-1] == 2:
            seghigh = ants.threshold_image( img, "Otsu", n_classes )
            # interp_type=1 -> nearest neighbor, keeps labels integral
            seglow = ants.resample_image( seghigh, tarshape, use_voxels=False, interp_type=1 )
            srout = inference( dimg, srmdl, segmentation=seglow, poly_order=poly_order, verbose=verbose )
            dimgupseg = srout['super_resolution_segmentation']
            dimgup = srout['super_resolution']
            segblock = ants.resample_image_to_target( seghigh, dimgupseg, interp_type='nearestNeighbor' )
            segimgnn = ants.resample_image_to_target( seglow, dimgupseg, interp_type='nearestNeighbor' )
            segblock[ dimgupseg == 0 ] = 0
            segimgnn[ dimgupseg == 0 ] = 0
            dicenn = ants.label_overlap_measures( segblock, segimgnn )['MeanOverlap'][0]
            dicesr = ants.label_overlap_measures( segblock, dimgupseg )['MeanOverlap'][0]
        else:
            dimgup = inference( dimg, srmdl, poly_order=poly_order, verbose=verbose )

        # Linear baseline and reference, both on the SR output grid.
        dimglin = ants.resample_image_to_target( dimg, dimgup, interp_type='linear' )
        imgblock = ants.resample_image_to_target( img, dimgup, interp_type='linear' )
        # Zero the outputs where the reference is zero so background
        # differences contribute no error to the metrics.
        dimgup[ imgblock == 0.0 ] = 0.0
        dimglin[ imgblock == 0.0 ] = 0.0

        dimwarning = any(
            img.shape[jj] != imgblock.shape[jj] for jj in range(img.dimension) )
        if dimwarning:
            print("NOTE: dimensions of downsampled to upsampled image do not match!!!")
            print("we force them to match but this suggests results may not be reliable.")

        # Short model label: basename with the conventional prefix/suffixes removed.
        mdlname = os.path.basename( model_filename )
        for token in ( "siq_default_sisr_", "_best_mdl.keras", "_best_mdl.h5" ):
            mdlname = mdlname.replace( token, "" )

        if verbose and dimwarning:
            print( "original img shape" )
            print( img.shape )
            print( "resampled img shape" )
            print( imgblock.shape )

        row = {
            "identifier": mdlname if identifier is None else identifier,
            "imgshape": "x".join( str(s) for s in imgblock.shape ),
            "mdl": mdlname,
            "mdlshape": "x".join( str(u) for u in upshape ),
            "PSNR.LIN": antspynet.psnr( imgblock, dimglin ),
            "PSNR.SR": antspynet.psnr( imgblock, dimgup ),
            "SSIM.LIN": antspynet.ssim( imgblock, dimglin ),
            "SSIM.SR": antspynet.ssim( imgblock, dimgup ),
            "DICE.NN": dicenn,
            "DICE.SR": dicesr,
            "dimwarning": dimwarning }
        if verbose:
            print( row )
        rows.append( row )

    return pd.DataFrame( rows )

Evaluates and compares the performance of multiple super-resolution models on a given image.

This function provides a standardized way to benchmark SR models. For each model, it performs the following steps:

  1. Loads the model and determines its upsampling factor.
  2. Downsamples the high-resolution input image (img) to create a low-resolution input, simulating a real-world scenario.
  3. Adds Gaussian noise to the low-resolution input to test for robustness.
  4. Runs inference using the model to generate a super-resolved output.
  5. Generates a baseline output by upsampling the low-res input with linear interpolation.
  6. Calculates PSNR and SSIM metrics comparing both the model's output and the baseline against the original high-resolution image.
  7. If a dual-channel (image + segmentation) model is detected, it also calculates Dice scores for segmentation performance.
  8. Aggregates all results into a pandas DataFrame for easy comparison.

Parameters

model_filenames : list of str A list of file paths to the Keras models (.h5, .keras) to be compared. img : ants.ANTsImage The high-resolution ground truth image. This image will be downsampled to create the input for the models. n_classes : int, optional The number of classes for Otsu's thresholding when auto-generating a segmentation for evaluating dual-channel models. Default is 3. poly_order : str or int, optional Method for intensity matching between the SR output and the reference. Options: 'hist' for histogram matching (default), an integer for polynomial regression, or None to disable. identifier : str, optional A custom identifier for the output DataFrame. If None, it is inferred from the model filename. Default is None. noise_sd : float, optional Standard deviation of the additive Gaussian noise applied to the downsampled image before inference. Default is 0.1. verbose : bool, optional If True, prints detailed progress and intermediate values. Default is False.

Returns

pd.DataFrame A DataFrame where each row corresponds to a model. Columns contain evaluation metrics (PSNR.SR, SSIM.SR, DICE.SR), baseline metrics (PSNR.LIN, SSIM.LIN, DICE.NN), and metadata.

Notes

When evaluating a 2-channel (segmentation) model, the primary metric for the segmentation task is the Dice score (DICE.SR). The intensity metrics (PSNR, SSIM) are still computed on the first channel.

def region_wise_super_resolution(image, mask, super_res_model, dilation_amount=4, verbose=False):
def region_wise_super_resolution(image, mask, super_res_model, dilation_amount=4, verbose=False):
    """
    Super-resolve each labeled region of ``mask`` independently and stitch
    the results into a linearly upsampled copy of ``image``.

    The model's upsampling factor is probed with an all-zero 8-voxel-per-axis
    input. ``image`` is linearly resampled and ``mask`` is nearest-neighbor
    resampled onto the high-resolution grid. For every non-zero label, a
    crop of ``image`` around the dilated label is super-resolved. The result
    is decropped onto the high-resolution grid and copied only into the
    voxels of that label.

    Arguments
    ---------
    image : ANTsImage
        Input image.

    mask : ANTsImage
        Integer-labeled segmentation mask with non-zero regions to upsample.

    super_res_model : tf.keras.Model
        Trained super-resolution model.

    dilation_amount : int
        Number of morphological dilations applied to each label region before cropping.

    verbose : bool
        If True, print detailed status.

    Returns
    -------
    ANTsImage : Full-size super-resolved image with per-label inference and stitching.
    """
    import ants
    import numpy as np
    from antspynet import apply_super_resolution_model_to_image

    # Probe the model to learn its per-axis spatial upsampling factor.
    is_2d = len(super_res_model.inputs[0].shape) == 4
    probe_shape = [1, 8, 8, 1] if is_2d else [1, 8, 8, 8, 1]
    probe = np.zeros(probe_shape, dtype=np.float32)
    probe_out = super_res_model(probe)
    spatial_axes = range(1, len(probe_shape) - 1)  # skip batch and channel
    up_factor = [int(probe_out.shape[ax] / probe.shape[ax]) for ax in spatial_axes]

    # High-resolution grid derived from the mask's voxel dimensions.
    hr_size = tuple(int(dim * fac) for dim, fac in zip(mask.shape, up_factor))
    hr_mask = ants.resample_image(mask, hr_size, use_voxels=True, interp_type=1)
    hr_image = ants.resample_image(image, hr_size, use_voxels=True, interp_type=0)

    labels = [lab for lab in np.unique(hr_mask.numpy()) if lab != 0]

    result = ants.image_clone(hr_image)
    for lab in labels:
        if verbose:
            print(f"Processing label: {lab}")
        dilated = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount)
        patch = ants.crop_image(image, dilated)
        if patch.shape[0] == 0:
            continue
        patch_sr = apply_super_resolution_model_to_image(
            patch, super_res_model, target_range=[0, 1], verbose=verbose
        )
        placed = ants.decrop_image(patch_sr, result)
        in_label = hr_mask == lab
        result[in_label] = placed[in_label]

    return result

Apply super-resolution model to each labeled region in the mask independently.

Arguments

image : ANTsImage Input image.

mask : ANTsImage Integer-labeled segmentation mask with non-zero regions to upsample.

super_res_model : tf.keras.Model Trained super-resolution model.

dilation_amount : int Number of morphological dilations applied to each label region before cropping.

verbose : bool If True, print detailed status.

Returns

ANTsImage : Full-size super-resolved image with per-label inference and stitching.

def region_wise_super_resolution_blended(image, mask, super_res_model, dilation_amount=4, verbose=False):
def region_wise_super_resolution_blended(image, mask, super_res_model, dilation_amount=4, verbose=False,
                                         smoothing_sigma=2.0, normalize_weight_maps=True):
    """
    Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts.

    For each non-zero label, a dilated crop of the low-resolution image is
    super-resolved and decropped onto the high-resolution grid. A weight map
    is built from the (non-dilated) label mask: it is linearly resampled to
    the high-resolution grid, Gaussian-smoothed and, optionally, rescaled
    with ``iMath(..., "Normalize")``. The patches are accumulated as a
    weighted sum and divided by the summed weights. This averages
    overlapping regions smoothly.

    Voxels whose summed weight exceeds float32 machine epsilon receive that
    weighted average of super-resolved patches; the background is not mixed
    in there. All other voxels keep the linearly upsampled input image.

    Arguments
    ---------
    image : ANTsImage
        Input low-resolution image.

    mask : ANTsImage
        Integer-labeled segmentation mask.

    super_res_model : tf.keras.Model
        Trained super-resolution model.

    dilation_amount : int
        Number of morphological dilations applied to each label region before cropping.
        This provides context to the SR model.

    verbose : bool
        If True, print detailed status.

    smoothing_sigma : float
        Sigma of the Gaussian applied to each weight map, in voxels of the
        high-resolution grid. Default 2.0.

    normalize_weight_maps : bool
        If True (default), each smoothed weight map is passed through
        ``ants.iMath(..., "Normalize")``.

    Returns
    -------
    ANTsImage : Full-size, super-resolved image with blended regions.
    """
    import ants
    import numpy as np
    from antspynet import apply_super_resolution_model_to_image
    epsilon32 = np.finfo(np.float32).eps

    # --- Step 1: Determine upsampling factor and prepare initial images ---
    upFactor = []
    input_shape = super_res_model.inputs[0].shape
    test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1]
    test_input = np.zeros(test_shape, dtype=np.float32)
    test_output = super_res_model(test_input)
    for k in range(len(test_shape) - 2):  # ignore batch + channel
        upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1]))

    original_size = image.shape
    new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor))

    # The linearly upsampled image serves as the background.
    background_sr_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0)

    # --- Step 2: Initialize accumulator and weight-sum canvases (float32) ---
    accumulator = ants.image_clone(background_sr_image).astype('float32') * 0.0
    weight_sum = ants.image_clone(accumulator)

    unique_labels = [l for l in np.unique(mask.numpy()) if l != 0]

    for lab in unique_labels:
        if verbose:
            print(f"Blending label: {lab}")

        # --- Step 3: Super-resolve a dilated patch (gives the model context) ---
        region_mask_dilated = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount)
        cropped_lowres = ants.crop_image(image, region_mask_dilated)
        if cropped_lowres.shape[0] == 0:
            continue

        sr_patch = apply_super_resolution_model_to_image(
            cropped_lowres, super_res_model, target_range=[0, 1], verbose=verbose
        )

        # Place the super-resolved patch back onto a full-sized canvas.
        sr_patch_full_size = ants.decrop_image(sr_patch, accumulator)

        # --- Step 4: Smooth weight map from the *non-dilated* label mask ---
        region_mask_original = ants.threshold_image(mask, lab, lab)
        weight_map = ants.resample_image(region_mask_original, new_size, use_voxels=True, interp_type=0)
        weight_map = ants.smooth_image(weight_map, sigma=smoothing_sigma,
                                       sigma_in_physical_coordinates=False)
        if normalize_weight_maps:
            weight_map = ants.iMath(weight_map, "Normalize")

        # --- Step 5: Accumulate weighted values and the weights themselves ---
        accumulator += sr_patch_full_size * weight_map
        weight_sum += weight_map

    # --- Step 6: Final combination ---
    weight_sum_np = weight_sum.numpy()
    accumulator_np = accumulator.numpy()

    # Only normalize where the accumulated weight is meaningfully non-zero;
    # elsewhere keep the upsampled background.
    blended_mask = weight_sum_np > epsilon32

    final_image_np = background_sr_image.numpy()
    final_image_np[blended_mask] = accumulator_np[blended_mask] / weight_sum_np[blended_mask]

    return ants.from_numpy(final_image_np, origin=background_sr_image.origin,
                           spacing=background_sr_image.spacing, direction=background_sr_image.direction)

Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts.

This version uses a weighted-averaging scheme, with Gaussian-smoothed region masks as the weights, to soften transitions between super-resolved regions.

Arguments

image : ANTsImage Input low-resolution image.

mask : ANTsImage Integer-labeled segmentation mask.

super_res_model : tf.keras.Model Trained super-resolution model.

dilation_amount : int Number of morphological dilations applied to each label region before cropping. This provides context to the SR model.

verbose : bool If True, print detailed status.

Returns

ANTsImage : Full-size, super-resolved image with seamless blending.

def inference( image, mdl, truncation=None, segmentation=None, target_range=[1, 0], poly_order='hist', dilation_amount=0, verbose=False):
def inference(
    image,
    mdl,
    truncation=None,
    segmentation=None,
    target_range=[1, 0],
    poly_order='hist',
    dilation_amount=0,
    verbose=False):
    """
    Perform super-resolution inference on an input image, optionally guided by segmentation.

    This function applies a trained deep learning model to enhance the resolution
    of an image. If a segmentation is provided, inference is performed label-wise.

    Parameters
    ----------
    image : ants.ANTsImage
        Input image to be super-resolved. It is cloned, so the caller's image is
        not modified.

    mdl : keras.Model
        Trained super-resolution model. The last dimension of its first input is
        taken as the channel count.

    truncation : tuple or list of float, optional
        Lower and upper quantiles (e.g., [0.01, 0.99]) passed to
        ``ants.iMath(..., 'TruncateIntensity', lower, upper)`` before inference.
        If None, no truncation is applied.

    segmentation : ants.ANTsImage, optional
        Integer-labeled segmentation mask. If provided, super-resolution is
        performed per label:
          - single-channel model: ``region_wise_super_resolution_blended``
          - multi-channel model: ``antspyt1w.super_resolution_segmentation_per_label``
            applied to every non-zero label (0 is treated as background).
            In this case the model input must have rank 4 (2D) or rank 5 (3D).

    target_range : list of float
        Intensity range used to scale the input before applying the model.
        Used by the multi-channel per-label path and the no-segmentation path;
        not passed to the single-channel segmentation path. Default is [1, 0].

    poly_order : int, str or None
        Determines how to match intensity between the super-resolved image and
        the (resampled) input. Options:
          - 'hist' : use histogram matching
          - int >= 1 : perform polynomial regression of this order
          - None : no intensity adjustment

    dilation_amount : int
        Number of dilation steps applied to each segmentation label during
        region-based super-resolution (only used if segmentation is provided).

    verbose : bool
        If True, print progress and status messages.

    Returns
    -------
    ants.ANTsImage
        The super-resolved image after intensity matching. An image is returned
        in every case, with or without ``segmentation``.

    Raises
    ------
    ValueError
        If a segmentation is combined with a multi-channel model whose input
        rank is neither 4 (2D) nor 5 (3D).

    Examples
    --------
    >>> import ants
    >>> import tensorflow as tf
    >>> from siq import inference
    >>> img = ants.image_read("lowres.nii.gz")
    >>> model = tf.keras.models.load_model("sr_model.keras", compile=False)
    >>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True)

    >>> seg = ants.image_read("mask.nii.gz")
    >>> srimg = inference(img, model, segmentation=seg)
    """
    import ants
    import numpy as np
    import antspynet
    import antspyt1w

    def apply_intensity_match(sr_image, reference_image, order, verbose=False):
        # Map the SR intensities back onto the reference image's distribution.
        if order is None:
            return sr_image
        if verbose:
            print("Applying intensity match with", order)
        if order == 'hist':
            return ants.histogram_match_image(sr_image, reference_image)
        return ants.regression_match_image(sr_image, reference_image, poly_order=order)

    def model_upsampling_factors(model_input_shape, channels):
        # Probe the model with a small all-zero patch and compare output vs.
        # input spatial sizes to infer the per-axis upsampling factor.
        spatial_dims = len(model_input_shape) - 2  # drop batch and channel axes
        if spatial_dims not in (2, 3):
            raise ValueError(
                "Expected a model input of rank 4 (2D) or 5 (3D), got shape "
                f"{tuple(model_input_shape)}")
        # Use the model's real channel count (not a hard-coded 2) so that
        # models with any number of input channels can be probed.
        probe = np.zeros([1] + [8] * spatial_dims + [channels])
        probe_out = mdl(probe)
        return [int(probe_out.shape[k + 1]) // probe.shape[k + 1]
                for k in range(spatial_dims)]

    pimg = ants.image_clone(image)
    if truncation is not None:
        pimg = ants.iMath(pimg, 'TruncateIntensity', truncation[0], truncation[1])

    input_shape = mdl.inputs[0].shape
    num_channels = int(input_shape[-1])

    if segmentation is not None and num_channels == 1:
        if verbose:
            print("Using region-wise super resolution due to single-channel model with segmentation.")
        sr = region_wise_super_resolution_blended(
            pimg, segmentation, mdl,
            dilation_amount=dilation_amount,
            verbose=verbose
        )
    elif segmentation is not None:
        unique_values = np.unique(segmentation.numpy())
        # Keep every non-background label. The previous slice [1:len(array)]
        # silently dropped labels beyond the image's first-axis size.
        labels = list(unique_values[unique_values != 0].astype(int))
        up_factor = model_upsampling_factors(input_shape, num_channels)
        per_label = antspyt1w.super_resolution_segmentation_per_label(
            pimg,
            segmentation,
            up_factor,
            mdl,
            segmentation_numbers=labels,
            target_range=target_range,
            dilation_amount=dilation_amount,
            poly_order=poly_order,
            max_lab_plus_one=True
        )
        # NOTE(review): poly_order is applied inside the per-label call and
        # again below via apply_intensity_match - confirm this is intended.
        sr = per_label['super_resolution']
    else:
        # Default path: no segmentation, whole-image inference.
        sr = antspynet.apply_super_resolution_model_to_image(
            pimg, mdl, target_range=target_range, regression_order=None, verbose=verbose
        )

    # Resample the (truncated) input onto the SR grid so it can serve as the
    # intensity reference.
    ref = ants.resample_image_to_target(pimg, sr)
    return apply_intensity_match(sr, ref, poly_order, verbose)

Perform super-resolution inference on an input image, optionally guided by segmentation.

This function uses a trained deep learning model to enhance the resolution of a medical image. It optionally applies label-wise inference if a segmentation mask is provided.

Parameters

image : ants.ANTsImage Input image to be super-resolved.

mdl : keras.Model Trained super-resolution model, typically from ANTsPyNet.

truncation : tuple or list of float, optional Lower and upper quantiles (e.g., [0.01, 0.99]) used for intensity truncation (ants.iMath 'TruncateIntensity') before model input. If None, no truncation is applied.

segmentation : ants.ANTsImage, optional A labeled segmentation mask. If provided, super-resolution is performed per label: using region_wise_super_resolution_blended for single-channel models, or antspyt1w.super_resolution_segmentation_per_label for multi-channel models.

target_range : list of float Intensity range used for scaling the input before applying the model. Default is [1, 0] (internal default for apply_super_resolution_model_to_image).

poly_order : int, str or None Determines how to match intensity between the super-resolved image and the original. Options: 'hist' (histogram matching); an int >= 1 (polynomial regression of this order); or None (no intensity adjustment).

dilation_amount : int Number of dilation steps applied to each segmentation label during region-based super-resolution (if segmentation is provided).

verbose : bool If True, print progress and status messages.

Returns

ANTsImage : The super-resolved image after intensity matching. An ANTsImage is returned whether or not segmentation is provided.

Examples

>>> import ants
>>> import antspynet
>>> from siq import inference
>>> img = ants.image_read("lowres.nii.gz")
>>> model = tf.keras.models.load_model("sr_model.keras", compile=False)
>>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True)
>>> seg = ants.image_read("mask.nii.gz")
>>> srimg = inference(img, model, segmentation=seg)