siq.get_data
from pathlib import Path
from pathlib import PurePath
import os
import pandas as pd
import math
import os.path
from os import path
from os.path import exists
import pickle
import sys
import numpy as np
import random
import functools
from operator import mul
from scipy.sparse.linalg import svds
from scipy.stats import pearsonr
import re

import ants
import antspynet
import antspyt1w
import tensorflow as tf
from tensorflow.python.eager.context import eager_mode, graph_mode

from multiprocessing import Pool

DATA_PATH = os.path.expanduser('~/.siq/')

def get_data( name=None, force_download=False, version=0, target_extension='.csv' ):
    """
    Get SIQ data filename

    The first time this is called, it will download data to ~/.siq.
    After, it will just read data from disk.  The ~/.siq may need to
    be periodically deleted in order to ensure data is current.

    Arguments
    ---------
    name : string
        name of data tag to retrieve
        Options:
            - 'all'

    force_download: boolean

    version: version of data to download (integer)

    Returns
    -------
    string
        filepath of selected data

    Example
    -------
    >>> import siq
    >>> siq.get_data()
    """
    os.makedirs(DATA_PATH, exist_ok=True)

    def download_data( version ):
        url = "https://figshare.com/ndownloader/articles/16912366/versions/" + str(version)
        target_file_name = "16912366.zip"
        target_file_name_path = tf.keras.utils.get_file(target_file_name, url,
            cache_subdir=DATA_PATH, extract = True )
        os.remove( DATA_PATH + target_file_name )

    if force_download:
        download_data( version = version )

    files = []
    for fname in os.listdir(DATA_PATH):
        if ( fname.endswith(target_extension) ) :
            fname = os.path.join(DATA_PATH, fname)
            files.append(fname)

    if len( files ) == 0 :
        download_data( version = version )
        for fname in os.listdir(DATA_PATH):
            if ( fname.endswith(target_extension) ) :
                fname = os.path.join(DATA_PATH, fname)
                files.append(fname)

    if name == 'all':
        return files

    datapath = None

    for fname in os.listdir(DATA_PATH):
        mystem = (Path(fname).resolve().stem)
        mystem = (Path(mystem).resolve().stem)
        mystem = (Path(mystem).resolve().stem)
        if ( name == mystem and fname.endswith(target_extension) ) :
            datapath = os.path.join(DATA_PATH, fname)

    return datapath


from keras.models import Model
from keras.layers import (Input, Add, Subtract,
                          PReLU, Concatenate,
                          UpSampling2D, UpSampling3D,
                          Conv2D, Conv2DTranspose,
                          Conv3D, Conv3DTranspose)


# define the DBPN network - this uses a model definition that is general to
# both 2D and 3D. recommended parameters for different upsampling rates can
# be found in the papers by Haris et al. We make one significant change to
# the original architecture by allowing standard interpolation for upsampling
# instead of convolutional upsampling. this is controlled by the interpolation
# option.


def dbpn(input_image_size,
         number_of_outputs=1,
         number_of_base_filters=64,
         number_of_feature_filters=256,
         number_of_back_projection_stages=7,
         convolution_kernel_size=(12, 12),
         strides=(8, 8),
         last_convolution=(3, 3),
         number_of_loss_functions=1,
         interpolation = 'nearest'
        ):
    """
    Creates a Deep Back-Projection Network (DBPN) for single image super-resolution.

    This function constructs a Keras model based on the DBPN architecture, which
    can be configured for either 2D or 3D inputs. The network uses iterative
    up- and down-projection blocks to refine the high-resolution image estimate. A
    key modification from the original paper is the option to use standard
    interpolation for upsampling instead of deconvolution layers.

    Reference:
     - Haris, M., Shakhnarovich, G., & Ukita, N. (2018). Deep Back-Projection
       Networks For Super-Resolution. In CVPR.

    Parameters
    ----------
    input_image_size : tuple or list
        The shape of the input image, including the channel.
        e.g., `(None, None, 1)` for 2D or `(None, None, None, 1)` for 3D.

    number_of_outputs : int, optional
        The number of channels in the output image. Default is 1.

    number_of_base_filters : int, optional
        The number of filters in the up/down projection blocks. Default is 64.

    number_of_feature_filters : int, optional
        The number of filters in the initial feature extraction layer. Default is 256.

    number_of_back_projection_stages : int, optional
        The number of iterative back-projection stages (T in the paper). Default is 7.

    convolution_kernel_size : tuple or list, optional
        The kernel size for the main projection convolutions. Should match the
        dimensionality of the input. Default is (12, 12).

    strides : tuple or list, optional
        The strides for the up/down sampling operations, defining the
        super-resolution factor. Default is (8, 8).

    last_convolution : tuple or list, optional
        The kernel size of the final reconstruction convolution. Default is (3, 3).

    number_of_loss_functions : int, optional
        If greater than 1, the model will have multiple identical output branches.
        Typically set to 1. Default is 1.

    interpolation : str, optional
        The interpolation method to use for upsampling layers if not using
        transposed convolutions. 'nearest' or 'bilinear'. Default is 'nearest'.

    Returns
    -------
    keras.Model
        A Keras model implementing the DBPN architecture for the specified
        parameters.
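
    Example
    -------
    Illustrative only; builds a 2D model for 2x upsampling (these parameter
    choices are assumptions, not recommendations from the original docstring):

    >>> import siq
    >>> mdl = siq.dbpn( (None, None, 1),
    ...                 convolution_kernel_size=(6, 6),
    ...                 strides=(2, 2) )
    >>> mdl.summary()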
    """
    idim = len( input_image_size ) - 1
    if idim == 2:
        myconv = Conv2D
        myconv_transpose = Conv2DTranspose
        myupsampling = UpSampling2D
        shax = ( 1, 2 )
        firstConv = (3,3)
        firstStrides=(1,1)
        smashConv=(1,1)
    if idim == 3:
        myconv = Conv3D
        myconv_transpose = Conv3DTranspose
        myupsampling = UpSampling3D
        shax = ( 1, 2, 3 )
        firstConv = (3,3,3)
        firstStrides=(1,1,1)
        smashConv=(1,1,1)

    def up_block_2d(L, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
                    include_dense_convolution_layer=True):
        if include_dense_convolution_layer == True:
            L = myconv(filters = number_of_filters,
                       use_bias=True,
                       kernel_size=smashConv,
                       strides=firstStrides,
                       padding='same')(L)
            L = PReLU(alpha_initializer='zero',
                      shared_axes=shax)(L)
        # Scale up
        if idim == 2:
            H0 = myupsampling( size = strides, interpolation=interpolation )(L)
        if idim == 3:
            H0 = myupsampling( size = strides )(L)
        H0 = myconv(filters=number_of_filters,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    use_bias=True,
                    padding='same')(H0)
        H0 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(H0)
        # Scale down
        L0 = myconv(filters=number_of_filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(H0)
        L0 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(L0)
        # Residual
        E = Subtract()([L0, L])
        # Scale residual up
        if idim == 2:
            H1 = myupsampling( size = strides, interpolation=interpolation )(E)
        if idim == 3:
            H1 = myupsampling( size = strides )(E)
        H1 = myconv(filters=number_of_filters,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    use_bias=True,
                    padding='same')(H1)
        H1 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(H1)
        # Output feature map
        up_block = Add()([H0, H1])
        return up_block

    def down_block_2d(H, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8),
                      include_dense_convolution_layer=True):
        if include_dense_convolution_layer == True:
            H = myconv(filters = number_of_filters,
                       use_bias=True,
                       kernel_size=smashConv,
                       strides=firstStrides,
                       padding='same')(H)
            H = PReLU(alpha_initializer='zero',
                      shared_axes=shax)(H)
        # Scale down
        L0 = myconv(filters=number_of_filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(H)
        L0 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(L0)
        # Scale up
        if idim == 2:
            H0 = myupsampling( size = strides, interpolation=interpolation )(L0)
        if idim == 3:
            H0 = myupsampling( size = strides )(L0)
        H0 = myconv(filters=number_of_filters,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    use_bias=True,
                    padding='same')(H0)
        H0 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(H0)
        # Residual
        E = Subtract()([H0, H])
        # Scale residual down
        L1 = myconv(filters=number_of_filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(E)
        L1 = PReLU(alpha_initializer='zero',
                   shared_axes=shax)(L1)
        # Output feature map
        down_block = Add()([L0, L1])
        return down_block

    inputs = Input(shape=input_image_size)
    # Initial feature extraction
    model = myconv(filters=number_of_feature_filters,
                   kernel_size=firstConv,
                   strides=firstStrides,
                   padding='same',
                   kernel_initializer='glorot_uniform')(inputs)
    model = PReLU(alpha_initializer='zero',
                  shared_axes=shax)(model)
    # Feature smashing
    model = myconv(filters=number_of_base_filters,
                   kernel_size=smashConv,
                   strides=firstStrides,
                   padding='same',
                   kernel_initializer='glorot_uniform')(model)
    model = PReLU(alpha_initializer='zero',
                  shared_axes=shax)(model)
    # Back projection
    up_projection_blocks = []
    down_projection_blocks = []
    model = up_block_2d(model, number_of_filters=number_of_base_filters,
                        kernel_size=convolution_kernel_size, strides=strides)
    up_projection_blocks.append(model)
    for i in range(number_of_back_projection_stages):
        if i == 0:
            model = down_block_2d(model, number_of_filters=number_of_base_filters,
                                  kernel_size=convolution_kernel_size, strides=strides)
            down_projection_blocks.append(model)
            model = up_block_2d(model, number_of_filters=number_of_base_filters,
                                kernel_size=convolution_kernel_size, strides=strides)
            up_projection_blocks.append(model)
            model = Concatenate()(up_projection_blocks)
        else:
            model = down_block_2d(model, number_of_filters=number_of_base_filters,
                                  kernel_size=convolution_kernel_size, strides=strides,
                                  include_dense_convolution_layer=True)
            down_projection_blocks.append(model)
            model = Concatenate()(down_projection_blocks)
            model = up_block_2d(model, number_of_filters=number_of_base_filters,
                                kernel_size=convolution_kernel_size, strides=strides,
                                include_dense_convolution_layer=True)
            up_projection_blocks.append(model)
            model = Concatenate()(up_projection_blocks)
    outputs = myconv(filters=number_of_outputs,
                     kernel_size=last_convolution,
                     strides=firstStrides,
                     padding = 'same',
                     kernel_initializer = "glorot_uniform")(model)
    if number_of_loss_functions == 1:
        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputs)
    else:
        outputList=[]
        for k in range(number_of_loss_functions):
            outputList.append(outputs)
        deep_back_projection_network_model = Model(inputs=inputs, outputs=outputList)
    return deep_back_projection_network_model


# generate a random corner index for a patch
def get_random_base_ind( full_dims, patchWidth, off=8 ):
    """
    Generates a random top-left corner index for a patch.

    This utility function computes a valid starting index (e.g., [x, y, z])
    for extracting a patch from a larger volume, ensuring the patch fits entirely
    within the volume's boundaries, accounting for an offset.

    Parameters
    ----------
    full_dims : tuple or list
        The dimensions of the full volume (e.g., img.shape).

    patchWidth : tuple or list
        The dimensions of the patch to be extracted.

    off : int, optional
        An offset from the edge of the volume to avoid sampling near borders.
        Default is 8.

    Returns
    -------
    list
        A list of integers representing the starting coordinates for the patch.
    """
    baseInd = [None,None,None]
    for k in range(3):
        baseInd[k]=random.sample( range( off, full_dims[k]-1-patchWidth[k] ), 1 )[0]
    return baseInd


# extract a random patch
def get_random_patch( img, patchWidth ):
    """
    Extracts a random patch from an image with non-zero variance.

    This function repeatedly samples a random patch of a specified width from
    the input image until it finds one where the standard deviation of pixel
    intensities is greater than zero.
    This is useful for avoiding blank or uniform patches during training data
    generation.

    Parameters
    ----------
    img : ants.ANTsImage
        The source image from which to extract a patch.

    patchWidth : tuple or list
        The desired dimensions of the output patch.

    Returns
    -------
    ants.ANTsImage
        A randomly extracted patch from the input image.
    """
    mystd = 0
    while mystd == 0:
        inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 )
        hinds = [None,None,None]
        for k in range(len(inds)):
            hinds[k] = inds[k] + patchWidth[k]
        myimg = ants.crop_indices( img, inds, hinds )
        mystd = myimg.std()
    return myimg


def get_random_patch_pair( img, img2, patchWidth ):
    """
    Extracts a corresponding random patch from a pair of images.

    This function finds a single random location and extracts a patch of the
    same size and position from two different input images. It ensures that
    both extracted patches have non-zero variance. This is useful for creating
    paired training data (e.g., low-res and high-res images).

    Parameters
    ----------
    img : ants.ANTsImage
        The first source image.

    img2 : ants.ANTsImage
        The second source image, spatially aligned with the first.

    patchWidth : tuple or list
        The desired dimensions of the output patches.

    Returns
    -------
    tuple of ants.ANTsImage
        A tuple containing two corresponding patches: (patch_from_img, patch_from_img2).
    """
    mystd = mystd2 = 0
    ct = 0
    while mystd == 0 or mystd2 == 0:
        inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 )
        hinds = [None,None,None]
        for k in range(len(inds)):
            hinds[k] = inds[k] + patchWidth[k]
        myimg = ants.crop_indices( img, inds, hinds )
        myimg2 = ants.crop_indices( img2, inds, hinds )
        mystd = myimg.std()
        mystd2 = myimg2.std()
        ct = ct + 1
        if ( ct > 20 ):
            return myimg, myimg2
    return myimg, myimg2


def pseudo_3d_vgg_features( inshape = [128,128,128], layer = 4, angle=0, pretrained=True, verbose=False ):
    """
    Creates a pseudo-3D VGG feature extractor from a pre-trained 2D VGG model.

    This function constructs a 3D VGG-style network and initializes its weights
    by "stretching" the weights from a pre-trained 2D VGG19 model (trained on
    ImageNet) along a specified axis. This is a technique to transfer 2D
    perceptual knowledge to a 3D domain for tasks like perceptual loss.

    Parameters
    ----------
    inshape : list of int, optional
        The input shape of the 3D volume, e.g., `[128, 128, 128]`. Default is `[128,128,128]`.

    layer : int, optional
        The block number of the VGG network from which to extract features. For
        VGG19, this corresponds to block `layer` (e.g., layer=4 means 'block4_conv...').
        Default is 4.

    angle : int, optional
        The axis along which to project the 2D weights:
        - 0: Axial plane (stretches along Z)
        - 1: Coronal plane (stretches along Y)
        - 2: Sagittal plane (stretches along X)
        Default is 0.

    pretrained : bool, optional
        If True, loads the stretched ImageNet weights. If False, the model is
        randomly initialized. Default is True.

    verbose : bool, optional
        If True, prints information about the layers being used. Default is False.

    Returns
    -------
    tf.keras.Model
        A Keras model that takes a 3D volume as input and outputs the pseudo-3D
        feature map from the specified layer and angle.
    """
    def getLayerScaleFactorForTransferLearning( k, w3d, w2d ):
        myfact = np.round( np.prod( w3d[k].shape ) / np.prod( w2d[k].shape) )
        return myfact
    vgg19 = tf.keras.applications.VGG19(
        include_top = False, weights = "imagenet",
        input_shape = [inshape[0],inshape[1],3],
        classes = 1000 )
    def findLayerIndex( layerName, mdl ):
        for k in range( len( mdl.layers ) ):
            if layerName == mdl.layers[k].name :
                return k - 1
        return None
    layer_index = layer-1 # findLayerIndex( 'block2_conv2', vgg19 )
    vggmodelRaw = antspynet.create_vgg_model_3d(
        [inshape[0],inshape[1],inshape[2],1],
        number_of_classification_labels = 1000,
        layers = [1, 2, 3, 4, 4],
        lowest_resolution = 64,
        convolution_kernel_size= (3, 3, 3), pool_size = (2, 2, 2),
        strides = (2, 2, 2), number_of_dense_units= 4096, dropout_rate = 0,
        style = 19, mode = "classification")
    if verbose:
        print( vggmodelRaw.layers[layer_index] )
        print( vggmodelRaw.layers[layer_index].name )
        print( vgg19.layers[layer_index] )
        print( vgg19.layers[layer_index].name )
    feature_extractor_2d = tf.keras.Model(
        inputs = vgg19.input,
        outputs = vgg19.layers[layer_index].output)
    feature_extractor = tf.keras.Model(
        inputs = vggmodelRaw.input,
        outputs = vggmodelRaw.layers[layer_index].output)
    wts_2d = feature_extractor_2d.weights
    wts = feature_extractor.weights
    def checkwtshape( a, b ):
        if len(a.shape) != len(b.shape):
            return False
        for j in range(len(a.shape)):
            if a.shape[j] != b.shape[j]:
                return False
        return True
    for ww in range(len(wts)):
        wts[ww]=wts[ww].numpy()
        wts_2d[ww]=wts_2d[ww].numpy()
        if checkwtshape( wts[ww], wts_2d[ww] ) and ww != 0:
            wts[ww]=wts_2d[ww]
        elif ww != 0:
            # FIXME - should allow doing this across different angles
            if angle == 0:
                wts[ww][:,:,0,:,:]=wts_2d[ww]/3.0
                wts[ww][:,:,1,:,:]=wts_2d[ww]/3.0
                wts[ww][:,:,2,:,:]=wts_2d[ww]/3.0
            if angle == 1:
                wts[ww][:,0,:,:,:]=wts_2d[ww]/3.0
                wts[ww][:,1,:,:,:]=wts_2d[ww]/3.0
                wts[ww][:,2,:,:,:]=wts_2d[ww]/3.0
            if angle == 2:
                wts[ww][0,:,:,:,:]=wts_2d[ww]/3.0
                wts[ww][1,:,:,:,:]=wts_2d[ww]/3.0
                wts[ww][2,:,:,:,:]=wts_2d[ww]/3.0
        else:
            wts[ww][:,:,:,0,:]=wts_2d[ww]
    if pretrained:
        feature_extractor.set_weights( wts )
    newinput = tf.keras.layers.Rescaling( 255.0, -127.5 )( feature_extractor.input )
    feature_extractor2 = feature_extractor( newinput )
    feature_extractor = tf.keras.Model( feature_extractor.input, feature_extractor2 )
    return feature_extractor


def pseudo_3d_vgg_features_unbiased( inshape = [128,128,128], layer = 4, verbose=False ):
    """
    Create a pseudo-3D VGG-style feature extractor by aggregating axial, coronal,
    and sagittal VGG feature representations.

    This model extracts features along each principal axis using pre-trained 2D
    VGG-style networks and concatenates them to form an unbiased pseudo-3D feature space.

    Parameters
    ----------
    inshape : list of int, optional
        The input shape of the 3D volume, default is [128, 128, 128].

    layer : int, optional
        The VGG feature layer to extract. Higher values correspond to deeper
        layers in the pseudo-3D VGG backbone.

    verbose : bool, optional
        If True, prints debug messages during model construction.

    Returns
    -------
    tf.keras.Model
        A TensorFlow Keras model that takes a 3D input volume and outputs the
        concatenated pseudo-3D feature representation from the specified layer.

    Notes
    -----
    This is useful for perceptual loss or feature comparison in super-resolution
    and image synthesis tasks. The same input is processed in three anatomical
    planes (axial, coronal, sagittal), and features are concatenated.

    See Also
    --------
    pseudo_3d_vgg_features : Generates VGG features from a single anatomical plane.
    """
    f = [
        pseudo_3d_vgg_features( inshape, layer, angle=0, pretrained=True, verbose=verbose ),
        pseudo_3d_vgg_features( inshape, layer, angle=1, pretrained=True ),
        pseudo_3d_vgg_features( inshape, layer, angle=2, pretrained=True ) ]
    f1=f[0].inputs
    f0o=f[0]( f1 )
    f1o=f[1]( f1 )
    f2o=f[2]( f1 )
    catter = tf.keras.layers.concatenate( [f0o, f1o, f2o ])
    feature_extractor = tf.keras.Model( f1, catter )
    return feature_extractor


def get_grader_feature_network( layer=6 ):
    """
    Load and extract a ResNet-based feature subnetwork for perceptual loss or quality grading.

    This function loads a pre-trained 3D ResNet model ("grader") used for
    perceptual feature extraction and returns a subnetwork that outputs activations
    from a specified internal layer.

    Parameters
    ----------
    layer : int, optional
        The index of the internal ResNet layer whose output should be used as
        the feature representation. Default is layer 6.

    Returns
    -------
    tf.keras.Model
        A Keras model that outputs features from the specified layer of the
        pre-trained 3D ResNet grader model.

    Raises
    ------
    Exception
        If the pre-trained weights file (`resnet_grader.h5`) is not found.

    Notes
    -----
    The pre-trained weights are expected at `~/.antspyt1w/resnet_grader.h5`,
    which is available via `antspyt1w.get_data`.

    This model is typically used to compute perceptual loss by comparing
    intermediate activations between target and prediction volumes.

    See Also
    --------
    antspynet.create_resnet_model_3d : Constructs the base ResNet model.
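
    Example
    -------
    Illustrative only; assumes the grader weights have already been downloaded
    via antspyt1w:

    >>> import siq
    >>> feat = siq.get_grader_feature_network( layer=6 )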
    """
    grader = antspynet.create_resnet_model_3d(
        [None,None,None,1],
        lowest_resolution = 32,
        number_of_outputs = 4,
        cardinality = 1,
        squeeze_and_excite = False )
    # the folder and data below as available from antspyt1w get_data
    graderfn = os.path.expanduser( "~/.antspyt1w/resnet_grader.h5" )
    if not exists( graderfn ):
        raise Exception("graderfn " + graderfn + " does not exist")
    grader.load_weights( graderfn)
    # feature_extractor_23 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[23].output )
    # feature_extractor_44 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[44].output )
    return tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[layer].output )


def default_dbpn(
    strider,  # length should equal dimensionality
    dimensionality = 3,
    nfilt=64,
    nff = 256,
    convn = 6,
    lastconv = 3,
    nbp=7,
    nChannelsIn=1,
    nChannelsOut=1,
    option = None,
    intensity_model=None,
    segmentation_model=None,
    sigmoid_second_channel=False,
    clone_intensity_to_segmentation=False,
    pro_seg = 0,
    freeze = False,
    verbose=False
 ):
    """
    Constructs a DBPN model based on input parameters, and can optionally
    use external models for intensity or segmentation processing.

    Args:
        strider (list): List of strides, length must match `dimensionality`.
        dimensionality (int): Number of dimensions (2 or 3). Default is 3.
        nfilt (int): Number of base filters. Default is 64.
        nff (int): Number of feature filters. Default is 256.
        convn (int): Convolution kernel size. Default is 6.
        lastconv (int): Size of the last convolution. Default is 3.
        nbp (int): Number of back projection stages. Default is 7.
        nChannelsIn (int): Number of input channels. Default is 1.
        nChannelsOut (int): Number of output channels. Default is 1.
        option (str): Model size option ('tiny', 'small', 'medium', 'large'). Default is None.
        intensity_model (tf.keras.Model): Optional external intensity model.
        segmentation_model (tf.keras.Model): Optional external segmentation model.
        sigmoid_second_channel (bool): If True, applies sigmoid to second channel in output.
        clone_intensity_to_segmentation (bool): If True, clones intensity model weights to segmentation.
        pro_seg (int): If greater than 0, adds a segmentation arm.
        freeze (bool): If True, freezes the layers of the intensity/segmentation models.
        verbose (bool): If True, prints detailed logs.

    Returns:
        Model: A Keras model based on the specified configuration.

    Raises:
        Exception: If `len(strider)` is not equal to `dimensionality`.
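
    Example:
        Illustrative only; builds a small 2x-upsampling 3D model (the
        parameter choices here are assumptions, not recommendations).

        >>> import siq
        >>> mdl = siq.default_dbpn( [2,2,2], dimensionality=3, option='tiny' )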
    """
    if option == 'tiny':
        nfilt=32
        nff = 64
        convn = 3
        lastconv = 1
        nbp=2
    elif option == 'small':
        nfilt=32
        nff = 64
        convn = 6
        lastconv = 3
        nbp=4
    elif option == 'medium':
        nfilt=64
        nff = 128
        convn = 6
        lastconv = 3
        nbp=4
    else:
        option='large'
    if verbose:
        print("Building model of size: " + option)
        if intensity_model is not None:
            print("user-passed intensity model will be frozen - only segmentation will train")
        if segmentation_model is not None:
            print("user-passed segmentation model will be frozen - only intensity will train")

    if len(strider) != dimensionality:
        raise Exception("len(strider) != dimensionality")
    # **model instantiation**: these are close to defaults for the 2x network.
    # empirical evidence suggests that making convolutions and strides evenly
    # divisible by each other reduces artifacts. 2*3=6.
    # ofn='./models/dsr3d_'+str(strider)+'up_' + str(nfilt) + '_' + str( nff ) + '_' + str(convn)+ '_' + str(lastconv)+ '_' + str(os.environ['CUDA_VISIBLE_DEVICES'])+'_v0.0.keras'
    if dimensionality == 2 :
        mdl = dbpn( (None,None,nChannelsIn),
            number_of_outputs=nChannelsOut,
            number_of_base_filters=nfilt,
            number_of_feature_filters=nff,
            number_of_back_projection_stages=nbp,
            convolution_kernel_size=(convn, convn),
            strides=(strider[0], strider[1]),
            last_convolution=(lastconv, lastconv),
            number_of_loss_functions=1,
            interpolation='nearest')
    if dimensionality == 3 :
        mdl = dbpn( (None,None,None,nChannelsIn),
            number_of_outputs=nChannelsOut,
            number_of_base_filters=nfilt,
            number_of_feature_filters=nff,
            number_of_back_projection_stages=nbp,
            convolution_kernel_size=(convn, convn, convn),
            strides=(strider[0], strider[1], strider[2]),
            last_convolution=(lastconv, lastconv, lastconv),
            number_of_loss_functions=1, interpolation='nearest')
    if sigmoid_second_channel and pro_seg != 0 :
        if dimensionality == 2 :
            input_image_size = (None,None,2)
            if intensity_model is None:
                intensity_model = dbpn( (None,None,1),
                    number_of_outputs=1,
                    number_of_base_filters=nfilt,
                    number_of_feature_filters=nff,
                    number_of_back_projection_stages=nbp,
                    convolution_kernel_size=(convn, convn),
                    strides=(strider[0], strider[1]),
                    last_convolution=(lastconv, lastconv),
                    number_of_loss_functions=1,
                    interpolation='nearest')
            else:
                if freeze:
                    for layer in intensity_model.layers:
                        layer.trainable = False
            if segmentation_model is None:
                segmentation_model = dbpn( (None,None,1),
                    number_of_outputs=1,
                    number_of_base_filters=nfilt,
                    number_of_feature_filters=nff,
                    number_of_back_projection_stages=nbp,
                    convolution_kernel_size=(convn, convn),
                    strides=(strider[0], strider[1]),
                    last_convolution=(lastconv, lastconv),
                    number_of_loss_functions=1, interpolation='linear')
            else:
                if freeze:
                    for layer in segmentation_model.layers:
                        layer.trainable = False
        if dimensionality == 3 :
            input_image_size = (None,None,None,2)
            if intensity_model is None:
                intensity_model = dbpn( (None,None,None,1),
                    number_of_outputs=1,
                    number_of_base_filters=nfilt,
                    number_of_feature_filters=nff,
                    number_of_back_projection_stages=nbp,
                    convolution_kernel_size=(convn, convn, convn),
                    strides=(strider[0], strider[1], strider[2]),
                    last_convolution=(lastconv, lastconv, lastconv),
                    number_of_loss_functions=1,
                    interpolation='nearest')
            else:
                if freeze:
                    for layer in intensity_model.layers:
                        layer.trainable = False
            if segmentation_model is None:
                segmentation_model = dbpn( (None,None,None,1),
                    number_of_outputs=1,
                    number_of_base_filters=nfilt,
                    number_of_feature_filters=nff,
                    number_of_back_projection_stages=nbp,
                    convolution_kernel_size=(convn, convn, convn),
                    strides=(strider[0], strider[1], strider[2]),
                    last_convolution=(lastconv, lastconv, lastconv),
                    number_of_loss_functions=1, interpolation='linear')
            else:
                if freeze:
                    for layer in segmentation_model.layers:
                        layer.trainable = False
        if verbose:
            print( "len intensity_model layers : " + str( len( intensity_model.layers )))
            print( "len intensity_model weights : " + str( len( intensity_model.weights )))
            print( "len segmentation_model layers : " + str( len( segmentation_model.layers )))
            print( "len segmentation_model weights : " + str( len( segmentation_model.weights )))
        if clone_intensity_to_segmentation:
            for k in range(len( segmentation_model.weights )):
                if k < len( intensity_model.weights ):
                    if intensity_model.weights[k].shape == segmentation_model.weights[k].shape:
                        segmentation_model.weights[k] = intensity_model.weights[k]
        inputs = tf.keras.Input(shape=input_image_size)
        insplit = tf.split( inputs, 2, dimensionality+1)
        outputs = [
            intensity_model( insplit[0] ),
            tf.nn.sigmoid( segmentation_model( insplit[1] ) ) ]
        mdlout = tf.concat( outputs, axis=dimensionality+1 )
        return Model(inputs=inputs, outputs=mdlout )
    if pro_seg > 0 and intensity_model is not None:
        if verbose and freeze:
            print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid")
        if verbose and not freeze:
            print("Add a segmentation arm to the end. do not freeze intensity. intensity_model(seg) => conv => sigmoid")
        if freeze:
            for layer in intensity_model.layers:
                layer.trainable = False
        if dimensionality == 2 :
            input_image_size = (None,None,2)
        elif dimensionality == 3 :
            input_image_size = (None, None,None,2)
        if dimensionality == 2:
            myconv = Conv2D
            firstConv = (convn,convn)
            firstStrides=(1,1)
            smashConv=(pro_seg,pro_seg)
        if dimensionality == 3:
            myconv = Conv3D
            firstConv = (convn,convn,convn)
            firstStrides=(1,1,1)
            smashConv=(pro_seg,pro_seg,pro_seg)
        inputs = tf.keras.Input(shape=input_image_size)
        insplit = tf.split( inputs, 2, dimensionality+1)
        # define segmentation arm
        seggit = intensity_model( insplit[1] )
        L0 = myconv(filters=nff,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(seggit)
        L1 = myconv(filters=nff,
                    kernel_size=firstConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(L0)
        L2 = myconv(filters=1,
                    kernel_size=smashConv,
                    strides=firstStrides,
                    kernel_initializer='glorot_uniform',
                    padding='same')(L1)
        outputs = [
            intensity_model( insplit[0] ),
            tf.nn.sigmoid( L2 ) ]
        mdlout = tf.concat( outputs, axis=dimensionality+1 )
        return Model(inputs=inputs, outputs=mdlout )
    return mdl


def image_patch_training_data_from_filenames(
    filenames,
    target_patch_size,
    target_patch_size_low,
    nPatches = 128,
    istest = False,
    patch_scaler=True,
    to_tensorflow = False,
    verbose = False
    ):
    """
    Generates a batch of paired high- and low-resolution image patches for training.

    This function creates training data by taking a list of high-resolution source
    images, extracting random patches, and then downsampling them to create
    low-resolution counterparts. This provides the (input, ground_truth) pairs
    needed to train a super-resolution model.

    Parameters
    ----------
    filenames : list of str
        A list of file paths to the high-resolution source images.

    target_patch_size : tuple or list of int
        The dimensions of the high-resolution (ground truth) patch to extract,
        e.g., `(128, 128, 128)`.

    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution (input) patch to generate. The ratio
        between `target_patch_size` and `target_patch_size_low` determines the
        super-resolution factor.

    nPatches : int, optional
        The number of patch pairs to generate in this batch. Default is 128.

    istest : bool, optional
        If True, the function also generates a third output array containing patches
        that have been naively upsampled using linear interpolation. This is useful
        for calculating baseline evaluation metrics (e.g., PSNR) against which the
        model's performance can be compared. Default is False.

    patch_scaler : bool, optional
        If True, scales the intensity of each high-resolution patch to the [0, 1]
        range before creating the downsampled version. This can help with
        training stability. Default is True.

    to_tensorflow : bool, optional
        If True, casts the output NumPy arrays to TensorFlow tensors. Default is False.

    verbose : bool, optional
        If True, prints progress messages during patch generation. Default is False.

    Returns
    -------
    tuple
        A tuple of NumPy arrays or TensorFlow tensors.
        - If `istest` is False: `(patchesResam, patchesOrig)`
            - `patchesResam`: The batch of low-resolution input patches (X_train).
            - `patchesOrig`: The batch of high-resolution ground truth patches (y_train).
        - If `istest` is True: `(patchesResam, patchesOrig, patchesUp)`
            - `patchesUp`: The batch of baseline, linearly-upsampled patches.
    """
    if verbose:
        print("begin image_patch_training_data_from_filenames")
    tardim = len( target_patch_size )
    strider = []
    for j in range( tardim ):
        strider.append( np.round( target_patch_size[j]/target_patch_size_low[j]) )
    if tardim == 3:
        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],target_patch_size[2],1)
        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],target_patch_size_low[2],1)
    if tardim == 2:
        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],1)
        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],1)
    patchesOrig = np.zeros(shape=shaperhi)
    patchesResam = np.zeros(shape=shaperlo)
    patchesUp = None
    if istest:
        patchesUp = np.zeros(shape=patchesOrig.shape)
    for myn in range(nPatches):
        if verbose:
            print(myn)
        imgfn = random.sample( filenames, 1 )[0]
        if verbose:
            print(imgfn)
        img = ants.image_read( imgfn ).iMath("Normalize")
        if img.components > 1:
            img = ants.split_channels(img)[0]
        img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
        ants.set_origin( img, ants.get_center_of_mass(img) )
        img = ants.iMath(img,"Normalize")
        spc = ants.get_spacing( img )
        newspc = []
        for jj in range(len(spc)):
            newspc.append(spc[jj]*strider[jj])
        interp_type = random.choice( [0,1] )
        if True:
            imgp = get_random_patch( img, target_patch_size )
            imgpmin = imgp.min()
            if patch_scaler:
                imgp = imgp - imgpmin
                imgpmax = imgp.max()
                if imgpmax > 0 :
                    imgp = imgp / imgpmax
            rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type )
            if istest:
                rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0 )
            if tardim == 3:
                patchesOrig[myn,:,:,:,0] = imgp.numpy()
                patchesResam[myn,:,:,:,0] = rimgp.numpy()
                if istest:
                    patchesUp[myn,:,:,:,0] = rimgbi.numpy()
            if tardim == 2:
                patchesOrig[myn,:,:,0] = imgp.numpy()
                patchesResam[myn,:,:,0] = rimgp.numpy()
                if istest:
                    patchesUp[myn,:,:,0] = rimgbi.numpy()
    if to_tensorflow:
        patchesOrig = tf.cast( patchesOrig, "float32")
        patchesResam = tf.cast( patchesResam, "float32")
    if istest:
        if to_tensorflow:
            patchesUp = tf.cast( patchesUp, "float32")
    return patchesResam, patchesOrig, patchesUp


def seg_patch_training_data_from_filenames(
    filenames,
    target_patch_size,
    target_patch_size_low,
    nPatches = 128,
    istest = False,
    patch_scaler=True,
    to_tensorflow = False,
    verbose = False
    ):
    """
    Generates a batch of paired training data containing both images and segmentations.

    This function extends `image_patch_training_data_from_filenames` by adding a
    second channel to the data. For each extracted image patch, it also generates
    a corresponding segmentation mask using Otsu's thresholding. This is useful for
    training multi-task models that perform super-resolution on both an image and
    its associated segmentation simultaneously.

    Parameters
    ----------
    filenames : list of str
        A list of file paths to the high-resolution source images.

    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patch, e.g., `(128, 128, 128)`.

    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution input patch.

    nPatches : int, optional
        The number of patch pairs to generate. Default is 128.

    istest : bool, optional
        If True, also generates a third output array containing baseline upsampled
        intensity images (channel 0 only). Default is False.

    patch_scaler : bool, optional
        If True, scales the intensity of each image patch to the [0, 1] range.
        Default is True.

    to_tensorflow : bool, optional
        If True, casts the output NumPy arrays to TensorFlow tensors. Default is False.

    verbose : bool, optional
        If True, prints progress messages. Default is False.

    Returns
    -------
    tuple
        A tuple of multi-channel NumPy arrays or TensorFlow tensors. The structure
        is the same as `image_patch_training_data_from_filenames`, but each
        array has a channel dimension of 2:
        - Channel 0: The intensity image.
        - Channel 1: The binary segmentation mask.
    """
    if verbose:
        print("begin seg_patch_training_data_from_filenames")
    tardim = len( target_patch_size )
    strider = []
    nchan = 2
    for j in range( tardim ):
        strider.append( np.round( target_patch_size[j]/target_patch_size_low[j]) )
    if tardim == 3:
        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],target_patch_size[2],nchan)
        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],target_patch_size_low[2],nchan)
    if tardim == 2:
        shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],nchan)
        shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],nchan)
    patchesOrig = np.zeros(shape=shaperhi)
    patchesResam = np.zeros(shape=shaperlo)
    patchesUp = None
    if istest:
        patchesUp = np.zeros(shape=patchesOrig.shape)
    for myn in range(nPatches):
        if verbose:
            print(myn)
        imgfn = random.sample( filenames, 1 )[0]
        if verbose:
            print(imgfn)
        img = ants.image_read( imgfn ).iMath("Normalize")
        if img.components > 1:
            img = ants.split_channels(img)[0]
        img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) )
        ants.set_origin( img, ants.get_center_of_mass(img) )
        img = ants.iMath(img,"Normalize")
        spc = ants.get_spacing( img )
        newspc = []
        for jj in range(len(spc)):
            newspc.append(spc[jj]*strider[jj])
        interp_type = random.choice( [0,1] )
        seg_class = random.choice( [1,2] )
        if True:
            imgp = get_random_patch( img, target_patch_size )
            imgpmin = imgp.min()
            if patch_scaler:
                imgp = imgp - imgpmin
                imgpmax = imgp.max()
                if imgpmax > 0 :
                    imgp = imgp / imgpmax
            segp = ants.threshold_image( imgp, "Otsu", 2 ).threshold_image( seg_class, seg_class )
            rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type )
            rsegp = ants.resample_image( segp, newspc, use_voxels = False, interp_type=interp_type )
            if istest:
                rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0 )
            if tardim == 3:
                patchesOrig[myn,:,:,:,0] = imgp.numpy()
                patchesResam[myn,:,:,:,0] = rimgp.numpy()
                patchesOrig[myn,:,:,:,1] = segp.numpy()
                patchesResam[myn,:,:,:,1] = rsegp.numpy()
                if istest:
                    patchesUp[myn,:,:,:,0] = rimgbi.numpy()
            if tardim == 2:
                patchesOrig[myn,:,:,0] = imgp.numpy()
                patchesResam[myn,:,:,0] = rimgp.numpy()
                patchesOrig[myn,:,:,1] = segp.numpy()
                patchesResam[myn,:,:,1] = rsegp.numpy()
                if istest:
                    patchesUp[myn,:,:,0] = rimgbi.numpy()
    if to_tensorflow:
        patchesOrig = tf.cast( patchesOrig, "float32")
        patchesResam = tf.cast( patchesResam, "float32")
    if istest:
        if to_tensorflow:
            patchesUp = tf.cast( patchesUp, "float32")
    return patchesResam, patchesOrig, patchesUp


def read( filename ):
    """
    Reads an image or a NumPy array from a file.

    This function acts as a wrapper to intelligently load data. It checks the
    file extension to decide whether to use `ants.image_read` for standard
    medical image formats (e.g., .nii.gz, .mha) or `numpy.load` for `.npy` files.

    Parameters
    ----------
    filename : str
        The full path to the file to be read.

    Returns
    -------
    ants.ANTsImage or np.ndarray
        The loaded data object, either as an ANTsImage or a NumPy array.
    """
    import re
    isnpy = len( re.sub( ".npy", "", filename ) ) != len( filename )
    if not isnpy:
        myoutput = ants.image_read( filename )
    else:
        myoutput = np.load( filename )
    return myoutput


def auto_weight_loss( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    based on intensity (MSE), perceptual similarity (feature), and total variation (TV).

    Parameters
    ----------
    mdl : tf.keras.Model
        A trained or untrained model to evaluate predictions on input `x`.

    feature_extractor : tf.keras.Model
        A model that extracts intermediate features from the input. Commonly a VGG or ResNet
        trained on a perceptual task.

    x : tf.Tensor
        Input batch to the model.

    y : tf.Tensor
        Ground truth target for `x`, typically a batch of 2D or 3D volumes.

    feature : float, optional
        Weighting factor for the feature (perceptual) term in the loss. Default is 2.0.

    tv : float, optional
        Weighting factor for the total variation term in the loss. Default is 0.1.

    verbose : bool, optional
        If True, prints each component of the loss and its scaled value.

    Returns
    -------
    list of float
        A list of computed weights in the order:
        `[msq_weight, feature_weight, tv_weight]`

    Notes
    -----
    The total loss (to be used during training) can then be constructed as:

    `L = msq_weight * MSE + feature_weight * perceptual_loss + tv_weight * TV`

    This function is typically used to balance loss terms before training.
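
    Example
    -------
    Illustrative only; assumes a model `mdl` and a sample batch `(x_lo, y_hi)`
    have already been created (the names here are placeholders):

    >>> import siq
    >>> feat = siq.get_grader_feature_network( 6 )
    >>> wts = siq.auto_weight_loss( mdl, feat, x_lo, y_hi, feature=2.0, tv=0.1 )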
    """
    y_pred = mdl( x )
    squared_difference = tf.square( y - y_pred)
    if len( y.shape ) == 5:
        tdim = 3
        myax = [1,2,3,4]
    if len( y.shape ) == 4:
        tdim = 2
        myax = [1,2,3]
    msqTerm = tf.reduce_mean(squared_difference, axis=myax)
    temp1 = feature_extractor(y)
    temp2 = feature_extractor(y_pred)
    feature_difference = tf.square(temp1-temp2)
    featureTerm = tf.reduce_mean(feature_difference, axis=myax)
    msqw = 10.0
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version
            sqzd = y_pred[k,:,:,:,:]
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) )
    if tdim == 2:
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) )
    tvw = tv * msqw * msqTerm / mytv
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str( mytv * tvw ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean()]
    return wts


def auto_weight_loss_seg( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, dice=0.5, verbose=True ):
    """
    Automatically compute weighting coefficients for a combined loss function
    that includes MSE, perceptual similarity, total variation, and segmentation Dice loss.

    Parameters
    ----------
    mdl : tf.keras.Model
        A segmentation + super-resolution model that outputs both image and label predictions.

    feature_extractor : tf.keras.Model
        Feature extractor model used to compute perceptual similarity loss.

    x : tf.Tensor
        Input tensor to the model.

    y : tf.Tensor
        Target tensor with two channels: [intensity_image, segmentation_label].

    feature : float, optional
        Relative weight of the perceptual feature loss term. Default is 2.0.

    tv : float, optional
        Relative weight of the total variation (TV) term. Default is 0.1.

    dice : float, optional
        Relative weight of the Dice loss term (for segmentation agreement). Default is 0.5.

    verbose : bool, optional
        If True, prints the scaled values of each component loss.

    Returns
    -------
    list of float
        A list of loss term weights in the order:
        `[msq_weight, feature_weight, tv_weight, dice_weight]`

    Notes
    -----
    - The input and output tensors must be shaped such that the last axis is 2:
      channel 0 is intensity, channel 1 is segmentation.
    - This is useful for dual-task networks that predict both high-res images
      and associated segmentation masks.

    See Also
    --------
    binary_dice_loss : Computes Dice loss between predicted and ground-truth masks.
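
    Example
    -------
    Illustrative only; assumes a two-channel model `mdl` and a matching
    two-channel sample batch `(x_lo, y_hi)` exist (names are placeholders):

    >>> import siq
    >>> feat = siq.get_grader_feature_network( 6 )
    >>> wts = siq.auto_weight_loss_seg( mdl, feat, x_lo, y_hi,
    ...     feature=2.0, tv=0.1, dice=0.5 )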
    """
    y_pred = mdl( x )
    if len( y.shape ) == 5:
        tdim = 3
        myax = [1,2,3,4]
    if len( y.shape ) == 4:
        tdim = 2
        myax = [1,2,3]
    y_intensity = tf.split( y, 2, axis=tdim+1 )[0]
    y_seg = tf.split( y, 2, axis=tdim+1 )[1]
    y_intensity_p = tf.split( y_pred, 2, axis=tdim+1 )[0]
    y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )[1]
    squared_difference = tf.square( y_intensity - y_intensity_p )
    msqTerm = tf.reduce_mean(squared_difference, axis=myax)
    temp1 = feature_extractor(y_intensity)
    temp2 = feature_extractor(y_intensity_p)
    feature_difference = tf.square(temp1-temp2)
    featureTerm = tf.reduce_mean(feature_difference, axis=myax)
    msqw = 10.0
    featw = feature * msqw * msqTerm / featureTerm
    mytv = tf.cast( 0.0, 'float32')
    if tdim == 3:
        for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version
            sqzd = y_pred[k,:,:,:,0]
            mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) )
    if tdim == 2:
        mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0] ) )
    tvw = tv * msqw * msqTerm / mytv
    mydice = binary_dice_loss( y_seg, y_seg_p )
    mydice = tf.reduce_mean( mydice )
    dicew = dice * msqw * msqTerm / mydice
    dicewt = np.abs( dicew.numpy().mean() )
    if verbose :
        print( "MSQ: " + str( msqw * msqTerm ) )
        print( "Feat: " + str( featw * featureTerm ) )
        print( "Tv: " + str( mytv * tvw ) )
        print( "Dice: " + str( mydice * dicewt ) )
    wts = [msqw,featw.numpy().mean(),tvw.numpy().mean(), dicewt ]
    return wts


def numpy_generator( filenames ):
    """
    A placeholder or stub for a data generator.

    This generator yields a tuple of `None` values once and then stops. It is
    likely intended as a template or for debugging purposes where a generator
    object is required but no actual data needs to be processed.

    Parameters
    ----------
    filenames : any
        An argument that is not used by the function.

    Yields
    ------
    tuple
        A single tuple `(None, None, None)`.
    """
    patchesResam=patchesOrig=patchesUp=None
    yield (patchesResam, patchesOrig,patchesUp)


def image_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image patches for model training.

    This function continuously generates batches of low-resolution (input) and
    high-resolution (ground truth) image patches. It is designed to be fed
    directly into a Keras `model.fit()` call.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution (ground truth) patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution (input) patches.
    patch_scaler : bool, optional
        If True, scales patch intensities to [0, 1]. Default is True.
    istest : bool, optional
        If True, the generator will also yield a third item: a baseline
        linearly upsampled version of the low-resolution patch for comparison.
        Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    ------
    tuple
        A tuple of TensorFlow tensors ready for training or evaluation.
        - If `istest` is False: `(low_res_batch, high_res_batch)`
        - If `istest` is True: `(low_res_batch, high_res_batch, baseline_upsampled_batch)`

    See Also
    --------
    image_patch_training_data_from_filenames : The function that performs the
        underlying patch extraction.
    """
    while True:
        patchesResam, patchesOrig, patchesUp = image_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        if istest:
            yield (patchesResam, patchesOrig,patchesUp)
        else:
            yield (patchesResam, patchesOrig)


def seg_generator(
    filenames,
    nPatches,
    target_patch_size,
    target_patch_size_low,
    patch_scaler=True,
    istest=False,
    verbose = False ):
    """
    Creates an infinite generator of paired image and segmentation patches.

    This function continuously generates batches of multi-channel patches, where
    one channel is the intensity image and the other is a segmentation mask.
    It is designed for training multi-task super-resolution models.

    Parameters
    ----------
    filenames : list of str
        List of file paths to the high-resolution source images.
    nPatches : int
        The number of patch pairs to generate and yield in each batch.
    target_patch_size : tuple or list of int
        The dimensions of the high-resolution patches.
    target_patch_size_low : tuple or list of int
        The dimensions of the low-resolution patches.
    patch_scaler : bool, optional
        If True, scales the intensity channel of patches to [0, 1]. Default is True.
    istest : bool, optional
        If True, yields an additional baseline upsampled patch for comparison.
        Default is False.
    verbose : bool, optional
        If True, passes verbosity to the underlying patch generation function.
        Default is False.

    Yields
    ------
    tuple
        A tuple of multi-channel TensorFlow tensors. Each tensor has two channels:
        Channel 0 contains the intensity image, and Channel 1 contains the
        segmentation mask.

    See Also
    --------
    seg_patch_training_data_from_filenames : The function that performs the
        underlying patch extraction.
    image_generator : A similar generator for intensity-only data.
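
    Example
    -------
    Illustrative only; `my_filenames` is a placeholder list of image paths and
    the patch sizes imply 2x upsampling:

    >>> import siq
    >>> gen = siq.seg_generator( my_filenames, nPatches=2,
    ...     target_patch_size=[32,32,32], target_patch_size_low=[16,16,16] )
    >>> xlo, xhi = next( gen )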
    """
    while True:
        patchesResam, patchesOrig, patchesUp = seg_patch_training_data_from_filenames(
            filenames,
            target_patch_size = target_patch_size,
            target_patch_size_low = target_patch_size_low,
            nPatches = nPatches,
            istest = istest,
            patch_scaler=patch_scaler,
            to_tensorflow = True,
            verbose = verbose )
        if istest:
            yield (patchesResam, patchesOrig,patchesUp)
        else:
            yield (patchesResam, patchesOrig)


def train(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size,
    target_patch_size_low,
    output_prefix,
    n_test = 8,
    learning_rate=5e-5,
    feature_layer = 6,
    feature = 2,
    tv = 0.1,
    max_iterations = 1000,
    batch_size = 1,
    save_all_best = False,
    feature_type = 'grader',
    check_eval_data_iteration = 20,
    verbose = False ):
    """
    Orchestrates the training process for a super-resolution model.

    This function handles the entire training loop, including setting up data
    generators, defining a composite loss function, automatically balancing loss
    weights, iteratively training the model, periodically evaluating performance,
    and saving the best-performing model weights.

    Parameters
    ----------
    mdl : tf.keras.Model
        The Keras model to be trained.
    filenames_train : list of str
        List of file paths for the training dataset.
    filenames_test : list of str
        List of file paths for the validation/testing dataset.
    target_patch_size : tuple or list
        The dimensions of the high-resolution target patches.
    target_patch_size_low : tuple or list
        The dimensions of the low-resolution input patches.
    output_prefix : str
        A prefix for all output files (e.g., model weights, training logs).
    n_test : int, optional
        The number of validation patches to use for evaluation. Default is 8.
    learning_rate : float, optional
        The learning rate for the Adam optimizer. Default is 5e-5.
    feature_layer : int, optional
        The layer index from the feature extractor to use for perceptual loss.
        Default is 6.
    feature : float, optional
        The relative weight of the perceptual (feature) loss term. Default is 2.0.
    tv : float, optional
        The relative weight of the Total Variation (TV) regularization term.
        Default is 0.1.
    max_iterations : int, optional
        The total number of training iterations to run. Default is 1000.
    batch_size : int, optional
        The batch size for training. Note: this implementation is optimized for
        batch_size=1 and may need adjustment for larger batches. Default is 1.
    save_all_best : bool, optional
        If True, saves a new model file every time validation loss improves.
        If False, overwrites the single best model file. Default is False.
    feature_type : str, optional
        The type of feature extractor for perceptual loss. Options: 'grader',
        'vgg', 'vggrandom'. Default is 'grader'.
    check_eval_data_iteration : int, optional
        The frequency (in iterations) at which to run validation and save logs.
        Default is 20.
    verbose : bool, optional
        If True, prints detailed progress information. Default is False.

    Returns
    -------
    pd.DataFrame
        A DataFrame containing the training history, with columns for training
        loss, validation loss, PSNR, and baseline PSNR over iterations.
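
    Example
    -------
    Illustrative only; `fns_train` and `fns_test` are placeholder lists of
    image filenames, the patch sizes imply 2x upsampling, and the default
    'grader' feature extractor assumes its weights are already available:

    >>> import siq
    >>> mdl = siq.default_dbpn( [2,2,2], option='tiny' )
    >>> history = siq.train( mdl, fns_train, fns_test,
    ...     target_patch_size=[32,32,32],
    ...     target_patch_size_low=[16,16,16],
    ...     output_prefix='/tmp/siq_example',
    ...     max_iterations=2 )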
    """
    colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin']
    training_path = np.zeros( [ max_iterations, len(colnames) ] )
    training_weights = np.zeros( [1,3] )
    if verbose:
        print("begin get feature extractor " + feature_type)
    if feature_type == 'grader':
        feature_extractor = get_grader_feature_network( feature_layer )
    elif feature_type == 'vggrandom':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False )
    elif feature_type == 'vgg':
        with eager_mode():
            feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer )
    else:
        raise Exception("feature type does not exist")
    if verbose:
        print("begin train generator")
    mydatgen = image_generator(
        filenames_train,
        nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=False , verbose=False)
    if verbose:
        print("begin test generator")
    mydatgenTest = image_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest )
    if len( patchesOrigTeTf.shape ) == 5:
        tdim = 3
        myax = [1,2,3,4]
    if len( patchesOrigTeTf.shape ) == 4:
        tdim = 2
        myax = [1,2,3]
    if verbose:
        print("begin train generator #2 at dim: " + str( tdim))
    mydatgenTest = image_generator( filenames_test, nPatches=1,
        target_patch_size=target_patch_size,
        target_patch_size_low=target_patch_size_low,
        istest=True, verbose=True)
    patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest )
    for k in range( n_test - 1 ):
        mydatgenTest = image_generator( filenames_test, nPatches=1,
            target_patch_size=target_patch_size,
            target_patch_size_low=target_patch_size_low,
            istest=True, verbose=True)
        temp0, temp1, temp2 = next( mydatgenTest )
        patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0)
        patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0)
        patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0)
    if verbose:
        print("begin auto_weight_loss")
    wts_csv = output_prefix + "_training_weights.csv"
    if exists( wts_csv ):
        wtsdf = pd.read_csv( wts_csv )
        wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0]]
        if verbose:
            print( "preset weights:" )
    else:
        with eager_mode():
            wts = auto_weight_loss( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf,
                feature=feature, tv=tv )
        for k in range(len(wts)):
            training_weights[0,k]=wts[k]
        pd.DataFrame(training_weights, columns = ["msq","feat","tv"] ).to_csv( wts_csv )
        if verbose:
            print( "automatic weights:" )
    if verbose:
        print( wts )
    def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], mybs = batch_size ):
        """Composite loss: MSE + Perceptual Loss + Total Variation."""
        squared_difference = tf.square(y_true - y_pred)
        if len( y_true.shape ) == 5:
            tdim = 3
            myax = [1,2,3,4]
        if len( y_true.shape ) == 4:
            tdim = 2
            myax = [1,2,3]
        msqTerm = tf.reduce_mean(squared_difference, axis=myax)
        temp1 = feature_extractor(y_true)
        temp2 = feature_extractor(y_pred)
        feature_difference = tf.square(temp1-temp2)
tf.reduce_mean(feature_difference, axis=myax) 1637 loss = msqTerm * msqwt + featureTerm * fw 1638 mytv = tf.cast( 0.0, 'float32') 1639 # mybs = int( y_pred.shape[0] ) --- should work but ... ? 1640 if tdim == 3: 1641 for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version 1642 sqzd = y_pred[k,:,:,:,:] 1643 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt 1644 if tdim == 2: 1645 mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) ) * tvwt 1646 return( loss + mytv ) 1647 if verbose: 1648 print("begin model compilation") 1649 opt = tf.keras.optimizers.Adam( learning_rate=learning_rate ) 1650 mdl.compile(optimizer=opt, loss=my_loss_6) 1651 # set up some parameters for tracking performance 1652 bestValLoss=1e12 1653 bestSSIM=0.0 1654 bestQC0 = -1000 1655 bestQC1 = -1000 1656 if verbose: 1657 print( "begin training", flush=True ) 1658 for myrs in range( max_iterations ): 1659 tracker = mdl.fit( mydatgen, epochs=2, steps_per_epoch=4, verbose=1, 1660 validation_data=(patchesResamTeTf,patchesOrigTeTf) ) 1661 training_path[myrs,0]=tracker.history['loss'][0] 1662 training_path[myrs,1]=tracker.history['val_loss'][0] 1663 training_path[myrs,2]=0 1664 print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True ) 1665 if myrs % check_eval_data_iteration == 0: 1666 with tf.device("/cpu:0"): 1667 myofn = output_prefix + "_best_mdl.keras" 1668 if save_all_best: 1669 myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras" 1670 tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB ) 1671 if ( tester < bestValLoss ): 1672 print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True ) 1673 bestValLoss = tester 1674 tf.keras.models.save_model( mdl, myofn ) 1675 training_path[myrs,2]=1 1676 pp = mdl.predict( patchesResamTeTfB, batch_size = 1 ) 1677 myssimSR = tf.image.psnr( pp * 220, patchesOrigTeTfB* 220, max_val=255 ) 1678 myssimSR = tf.reduce_mean( myssimSR ).numpy() 1679 myssimBI = tf.image.psnr( patchesUpTeTfB * 220, patchesOrigTeTfB* 220, max_val=255 ) 1680 myssimBI = tf.reduce_mean( myssimBI ).numpy() 1681 print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ), flush=True ) 1682 training_path[myrs,3]=myssimSR # psnr 1683 training_path[myrs,4]=myssimBI # psnrlin 1684 pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" ) 1685 training_path = pd.DataFrame(training_path, columns = colnames ) 1686 return training_path 1687 1688 1689def binary_dice_loss(y_true, y_pred): 1690 """ 1691 Computes the Dice loss for binary segmentation tasks. 1692 1693 The Dice coefficient is a common metric for comparing the overlap of two samples. 1694 This loss function computes `1 - DiceCoefficient`, making it suitable for 1695 minimization during training. A smoothing factor is added to avoid division 1696 by zero when both the prediction and the ground truth are empty. 1697 1698 Parameters 1699 ---------- 1700 y_true : tf.Tensor 1701 The ground truth binary segmentation mask. Values should be 0 or 1. 1702 y_pred : tf.Tensor 1703 The predicted binary segmentation mask, typically with values in [0, 1] 1704 from a sigmoid activation. 1705 1706 Returns 1707 ------- 1708 tf.Tensor 1709 A scalar tensor representing the Dice loss. The value ranges from -1 (perfect 1710 match) to 0 (no overlap), though it's typically used as `1 - dice_coeff` 1711 or just `-dice_coeff` (as here). 
1712 """ 1713 smoothing_factor = 1e-4 1714 K = tf.keras.backend 1715 y_true_f = K.flatten(y_true) 1716 y_pred_f = K.flatten(y_pred) 1717 intersection = K.sum(y_true_f * y_pred_f) 1718 # This is -1 * Dice Similarity Coefficient 1719 return -1 * (2 * intersection + smoothing_factor)/(K.sum(y_true_f) + 1720 K.sum(y_pred_f) + smoothing_factor) 1721 1722def train_seg( 1723 mdl, 1724 filenames_train, 1725 filenames_test, 1726 target_patch_size, 1727 target_patch_size_low, 1728 output_prefix, 1729 n_test = 8, 1730 learning_rate=5e-5, 1731 feature_layer = 6, 1732 feature = 2, 1733 tv = 0.1, 1734 dice = 0.5, 1735 max_iterations = 1000, 1736 batch_size = 1, 1737 save_all_best = False, 1738 feature_type = 'grader', 1739 check_eval_data_iteration = 20, 1740 verbose = False ): 1741 """ 1742 Orchestrates training for a multi-task image and segmentation SR model. 1743 1744 This function extends the `train` function to handle models that predict 1745 both a super-resolved image and a super-resolved segmentation mask. It uses 1746 a four-component composite loss: MSE (for image), a perceptual loss (for 1747 image), Total Variation (for image), and Dice loss (for segmentation). 1748 1749 Parameters 1750 ---------- 1751 mdl : tf.keras.Model 1752 The 2-channel Keras model to be trained. 1753 filenames_train : list of str 1754 List of file paths for the training dataset. 1755 filenames_test : list of str 1756 List of file paths for the validation/testing dataset. 1757 target_patch_size : tuple or list 1758 The dimensions of the high-resolution target patches. 1759 target_patch_size_low : tuple or list 1760 The dimensions of the low-resolution input patches. 1761 output_prefix : str 1762 A prefix for all output files. 1763 n_test : int, optional 1764 Number of validation patches for evaluation. Default is 8. 1765 learning_rate : float, optional 1766 Learning rate for the Adam optimizer. Default is 5e-5. 1767 feature_layer : int, optional 1768 Layer from the feature extractor for perceptual loss. Default is 6. 1769 feature : float, optional 1770 Relative weight of the perceptual loss term. Default is 2.0. 1771 tv : float, optional 1772 Relative weight of the Total Variation regularization term. Default is 0.1. 1773 dice : float, optional 1774 Relative weight of the Dice loss term for the segmentation mask. 1775 Default is 0.5. 1776 max_iterations : int, optional 1777 Total number of training iterations. Default is 1000. 1778 batch_size : int, optional 1779 The batch size for training. Default is 1. 1780 save_all_best : bool, optional 1781 If True, saves all models that improve validation loss. Default is False. 1782 feature_type : str, optional 1783 Type of feature extractor for perceptual loss. Default is 'grader'. 1784 check_eval_data_iteration : int, optional 1785 Frequency (in iterations) for running validation. Default is 20. 1786 verbose : bool, optional 1787 If True, prints detailed progress information. Default is False. 1788 1789 Returns 1790 ------- 1791 pd.DataFrame 1792 A DataFrame containing the training history, including columns for losses 1793 and evaluation metrics like PSNR and Dice score. 1794 1795 See Also 1796 -------- 1797 train : The training function for single-task (intensity-only) models. 
1798 """ 1799 colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin','eval_msq','eval_dice'] 1800 training_path = np.zeros( [ max_iterations, len(colnames) ] ) 1801 training_weights = np.zeros( [1,4] ) 1802 if verbose: 1803 print("begin get feature extractor") 1804 if feature_type == 'grader': 1805 feature_extractor = get_grader_feature_network( feature_layer ) 1806 elif feature_type == 'vggrandom': 1807 feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False ) 1808 else: 1809 feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer ) 1810 if verbose: 1811 print("begin train generator") 1812 mydatgen = seg_generator( 1813 filenames_train, 1814 nPatches=1, 1815 target_patch_size=target_patch_size, 1816 target_patch_size_low=target_patch_size_low, 1817 istest=False , verbose=False) 1818 if verbose: 1819 print("begin test generator") 1820 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1821 target_patch_size=target_patch_size, 1822 target_patch_size_low=target_patch_size_low, 1823 istest=True, verbose=True) 1824 patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest ) 1825 if len( patchesOrigTeTf.shape ) == 5: 1826 tdim = 3 1827 myax = [1,2,3,4] 1828 if len( patchesOrigTeTf.shape ) == 4: 1829 tdim = 2 1830 myax = [1,2,3] 1831 if verbose: 1832 print("begin train generator #2 at dim: " + str( tdim)) 1833 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1834 target_patch_size=target_patch_size, 1835 target_patch_size_low=target_patch_size_low, 1836 istest=True, verbose=True) 1837 patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest ) 1838 for k in range( n_test - 1 ): 1839 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1840 target_patch_size=target_patch_size, 1841 target_patch_size_low=target_patch_size_low, 1842 istest=True, verbose=True) 1843 temp0, temp1, temp2 = next( mydatgenTest ) 1844 patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0) 1845 patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0) 1846 patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0) 1847 if verbose: 1848 print("begin auto_weight_loss_seg") 1849 wts_csv = output_prefix + "_training_weights.csv" 1850 if exists( wts_csv ): 1851 wtsdf = pd.read_csv( wts_csv ) 1852 wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0], wtsdf['dice'][0]] 1853 if verbose: 1854 print( "preset weights:" ) 1855 else: 1856 wts = auto_weight_loss_seg( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf, 1857 feature=feature, tv=tv, dice=dice ) 1858 for k in range(len(wts)): 1859 training_weights[0,k]=wts[k] 1860 pd.DataFrame(training_weights, columns = ["msq","feat","tv","dice"] ).to_csv( wts_csv ) 1861 if verbose: 1862 print( "automatic weights:" ) 1863 if verbose: 1864 print( wts ) 1865 def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], dicewt=wts[3], mybs = batch_size ): 1866 """Composite loss: MSE + Perceptual + TV + Dice.""" 1867 if len( y_true.shape ) == 5: 1868 tdim = 3 1869 myax = [1,2,3,4] 1870 if len( y_true.shape ) == 4: 1871 tdim = 2 1872 myax = [1,2,3] 1873 y_intensity = tf.split( y_true, 2, axis=tdim+1 )[0] 1874 y_seg = tf.split( y_true, 2, axis=tdim+1 )[1] 1875 y_intensity_p = tf.split( y_pred, 2, axis=tdim+1 )[0] 1876 y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )[1] 1877 squared_difference = tf.square(y_intensity - y_intensity_p) 1878 msqTerm = tf.reduce_mean(squared_difference, axis=myax) 1879 temp1 = 
feature_extractor(y_intensity) 1880 temp2 = feature_extractor(y_intensity_p) 1881 feature_difference = tf.square(temp1-temp2) 1882 featureTerm = tf.reduce_mean(feature_difference, axis=myax) 1883 loss = msqTerm * msqwt + featureTerm * fw 1884 mytv = tf.cast( 0.0, 'float32') 1885 if tdim == 3: 1886 for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version 1887 sqzd = y_pred[k,:,:,:,0] 1888 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt 1889 if tdim == 2: 1890 mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0] ) ) * tvwt 1891 dicer = tf.reduce_mean( dicewt * binary_dice_loss( y_seg, y_seg_p ) ) 1892 return( loss + mytv + dicer ) 1893 if verbose: 1894 print("begin model compilation") 1895 opt = tf.keras.optimizers.Adam( learning_rate=learning_rate ) 1896 mdl.compile(optimizer=opt, loss=my_loss_6) 1897 # set up some parameters for tracking performance 1898 bestValLoss=1e12 1899 bestSSIM=0.0 1900 bestQC0 = -1000 1901 bestQC1 = -1000 1902 if verbose: 1903 print( "begin training", flush=True ) 1904 for myrs in range( max_iterations ): 1905 tracker = mdl.fit( mydatgen, epochs=2, steps_per_epoch=4, verbose=1, 1906 validation_data=(patchesResamTeTf,patchesOrigTeTf) ) 1907 training_path[myrs,0]=tracker.history['loss'][0] 1908 training_path[myrs,1]=tracker.history['val_loss'][0] 1909 training_path[myrs,2]=0 1910 print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True ) 1911 if myrs % check_eval_data_iteration == 0: 1912 with tf.device("/cpu:0"): 1913 myofn = output_prefix + "_best_mdl.keras" 1914 if save_all_best: 1915 myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras" 1916 tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB ) 1917 if ( tester < bestValLoss ): 1918 print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True ) 1919 bestValLoss = tester 1920 tf.keras.models.save_model( mdl, myofn ) 1921 training_path[myrs,2]=1 1922 pp = mdl.predict( patchesResamTeTfB, batch_size = 1 ) 1923 pp = tf.split( pp, 2, axis=tdim+1 ) 1924 y_orig = tf.split( patchesOrigTeTfB, 2, axis=tdim+1 ) 1925 y_up = tf.split( patchesUpTeTfB, 2, axis=tdim+1 ) 1926 myssimSR = tf.image.psnr( pp[0] * 220, y_orig[0]* 220, max_val=255 ) 1927 myssimSR = tf.reduce_mean( myssimSR ).numpy() 1928 myssimBI = tf.image.psnr( y_up[0] * 220, y_orig[0]* 220, max_val=255 ) 1929 myssimBI = tf.reduce_mean( myssimBI ).numpy() 1930 squared_difference = tf.square(y_orig[0] - pp[0]) 1931 msqTerm = tf.reduce_mean(squared_difference).numpy() 1932 dicer = binary_dice_loss( y_orig[1], pp[1] ) 1933 dicer = tf.reduce_mean( dicer ).numpy() 1934 print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ) + " MSQ: " + str(msqTerm) + " DICE: " + str(dicer), flush=True ) 1935 training_path[myrs,3]=myssimSR # psnr 1936 training_path[myrs,4]=myssimBI # psnrlin 1937 training_path[myrs,5]=msqTerm # msq 1938 training_path[myrs,6]=dicer # dice 1939 pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" ) 1940 training_path = pd.DataFrame(training_path, columns = colnames ) 1941 return training_path 1942 1943 1944def read_srmodel(srfilename, custom_objects=None): 1945 """ 1946 Load a super-resolution model (h5, .keras, or SavedModel format), 1947 and determine its upsampling factor. 1948 1949 Parameters 1950 ---------- 1951 srfilename : str 1952 Path to the model file (.h5, .keras, or a SavedModel folder). 
1953 custom_objects : dict, optional 1954 Dictionary of custom objects used in the model (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)}) 1955 1956 Returns 1957 ------- 1958 model : tf.keras.Model 1959 The loaded model. 1960 upsampling_factor : list of int 1961 List describing the upsampling factor: 1962 - For 3D input: [x_up, y_up, z_up, channels] 1963 - For 2D input: [x_up, y_up, channels] 1964 1965 Example 1966 ------- 1967 >>> mdl, up = read_srmodel("mymodel.keras") 1968 >>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)}) 1969 """ 1970 1971 # Expand path and detect format 1972 srfilename = os.path.expanduser(srfilename) 1973 ext = os.path.splitext(srfilename)[1].lower() 1974 1975 if os.path.isdir(srfilename): 1976 # SavedModel directory 1977 model = tf.keras.models.load_model(srfilename, custom_objects=custom_objects, compile=False) 1978 elif ext in ['.h5', '.keras']: 1979 model = tf.keras.models.load_model(srfilename, custom_objects=custom_objects, compile=False) 1980 else: 1981 raise ValueError(f"Unsupported model format: {ext}") 1982 1983 # Determine channel index 1984 input_shape = model.input_shape 1985 if isinstance(input_shape, list): 1986 input_shape = input_shape[0] 1987 chanindex = 3 if len(input_shape) == 4 else 4 1988 nchan = int(input_shape[chanindex]) 1989 1990 # Run dummy input to compute upsampling factor 1991 try: 1992 if len(input_shape) == 5: # 3D 1993 dummy_input = np.zeros([1, 8, 8, 8, nchan]) 1994 else: # 2D 1995 dummy_input = np.zeros([1, 8, 8, nchan]) 1996 1997 # Handle named inputs if necessary 1998 try: 1999 output = model(dummy_input) 2000 except Exception: 2001 output = model({model.input_names[0]: dummy_input}) 2002 2003 outshp = output.shape 2004 if len(input_shape) == 5: 2005 return model, [int(outshp[1]/8), int(outshp[2]/8), int(outshp[3]/8), nchan] 2006 else: 2007 return model, [int(outshp[1]/8), int(outshp[2]/8), nchan] 2008 2009 except Exception as e: 2010 raise RuntimeError(f"Could not infer upsampling factor. Error: {e}") 2011 2012 2013def simulate_image( shaper=[32,32,32], n_levels=10, multiply=False ): 2014 """ 2015 generate an image of given shape and number of levels 2016 2017 Arguments 2018 --------- 2019 shaper : [x,y,z] or [x,y] 2020 2021 n_levels : int 2022 2023 multiply : boolean 2024 2025 Returns 2026 ------- 2027 2028 ants.image 2029 2030 """ 2031 img = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) * 0 2032 for k in range(n_levels): 2033 temp = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) 2034 temp = ants.smooth_image( temp, n_levels ) 2035 temp = ants.threshold_image( temp, "Otsu", 1 ) 2036 if multiply: 2037 temp = temp * k 2038 img = img + temp 2039 return img 2040 2041 2042def optimize_upsampling_shape( spacing, modality='T1', roundit=False, verbose=False ): 2043 """ 2044 Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing 2045 and imaging modality. This output is used to select an appropriate pretrained 2046 super-resolution model filename. 2047 2048 Parameters 2049 ---------- 2050 spacing : sequence of float 2051 Voxel spacing (physical size per voxel in mm) from the input image. 2052 Typically obtained from `ants.get_spacing(image)`. 2053 2054 modality : str, optional 2055 Imaging modality. 
Affects resolution thresholds: 2056 - 'T1' : anatomical MRI (default minimum spacing: 0.35 mm) 2057 - 'DTI' : diffusion MRI (default minimum spacing: 1.0 mm) 2058 - 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm) 2059 2060 roundit : bool, optional 2061 If True, uses rounded integer ratios for the upsampling shape. 2062 Otherwise, uses floor division with constraints. 2063 2064 verbose : bool, optional 2065 If True, prints detailed internal values and logic. 2066 2067 Returns 2068 ------- 2069 str 2070 Optimal upsampling shape string in the form 'AxBxC', 2071 e.g., '2x2x2', '4x4x2'. 2072 2073 Notes 2074 ----- 2075 - The function prevents upsampling ratios that would result in '1x1x1' 2076 by defaulting to '2x2x2'. 2077 - It also avoids uncommon ratios like '5' by rounding to the nearest valid option. 2078 - The returned string is commonly used to populate a model filename template: 2079 2080 Example: 2081 >>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1') 2082 >>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras') 2083 """ 2084 minspc = min( list( spacing ) ) 2085 maxspc = max( list( spacing ) ) 2086 ratio = maxspc/minspc 2087 if ratio == 1.0: 2088 ratio = 0.5 2089 roundratio = np.round( ratio ) 2090 tarshaperaw = [] 2091 tarshape = [] 2092 tarshaperound = [] 2093 for k in range( len( spacing ) ): 2094 locrat = spacing[k]/minspc 2095 newspc = spacing[k] * roundratio 2096 tarshaperaw.append( locrat ) 2097 if modality == "NM": 2098 if verbose: 2099 print("Using minspacing: 0.25") 2100 if newspc < 0.25 : 2101 locrat = spacing[k]/0.25 2102 elif modality == "DTI": 2103 if verbose: 2104 print("Using minspacing: 1.0") 2105 if newspc < 1.0 : 2106 locrat = spacing[k]/1.0 2107 else: # assume T1 2108 if verbose: 2109 print("Using minspacing: 0.35") 2110 if newspc < 0.35 : 2111 locrat = spacing[k]/0.35 2112 myint = int( locrat ) 2113 if ( myint == 0 ): 2114 myint = 1 2115 if myint == 5: 2116 myint = 4 2117 if ( myint > 6 ): 2118 myint = 6 2119 tarshape.append( str( myint ) ) 2120 tarshaperound.append( str( int(np.round( locrat )) ) ) 2121 if verbose: 2122 print("before emendation:") 2123 print( tarshaperaw ) 2124 print( tarshaperound ) 2125 print( tarshape ) 2126 allone = True 2127 if roundit: 2128 tarshape = tarshaperound 2129 for k in range( len( tarshape ) ): 2130 if tarshape[k] != "1": 2131 allone=False 2132 if allone: 2133 tarshape = ["2","2","2"] # default 2134 return "x".join(tarshape) 2135 2136def compare_models( model_filenames, img, n_classes=3, 2137 poly_order='hist', 2138 identifier=None, noise_sd=0.1,verbose=False ): 2139 """ 2140 Evaluates and compares the performance of multiple super-resolution models on a given image. 2141 2142 This function provides a standardized way to benchmark SR models. For each model, 2143 it performs the following steps: 2144 1. Loads the model and determines its upsampling factor. 2145 2. Downsamples the high-resolution input image (`img`) to create a low-resolution 2146 input, simulating a real-world scenario. 2147 3. Adds Gaussian noise to the low-resolution input to test for robustness. 2148 4. Runs inference using the model to generate a super-resolved output. 2149 5. Generates a baseline output by upsampling the low-res input with linear interpolation. 2150 6. Calculates PSNR and SSIM metrics comparing both the model's output and the 2151 baseline against the original high-resolution image. 2152 7. 
If a dual-channel (image + segmentation) model is detected, it also calculates 2153 Dice scores for segmentation performance. 2154 8. Aggregates all results into a pandas DataFrame for easy comparison. 2155 2156 Parameters 2157 ---------- 2158 model_filenames : list of str 2159 A list of file paths to the Keras models (.h5, .keras) to be compared. 2160 img : ants.ANTsImage 2161 The high-resolution ground truth image. This image will be downsampled to 2162 create the input for the models. 2163 n_classes : int, optional 2164 The number of classes for Otsu's thresholding when auto-generating a 2165 segmentation for evaluating dual-channel models. Default is 3. 2166 poly_order : str or int, optional 2167 Method for intensity matching between the SR output and the reference. 2168 Options: 'hist' for histogram matching (default), an integer for 2169 polynomial regression, or None to disable. 2170 identifier : str, optional 2171 A custom identifier for the output DataFrame. If None, it is inferred 2172 from the model filename. Default is None. 2173 noise_sd : float, optional 2174 Standard deviation of the additive Gaussian noise applied to the 2175 downsampled image before inference. Default is 0.1. 2176 verbose : bool, optional 2177 If True, prints detailed progress and intermediate values. Default is False. 2178 2179 Returns 2180 ------- 2181 pd.DataFrame 2182 A DataFrame where each row corresponds to a model. Columns contain evaluation 2183 metrics (PSNR.SR, SSIM.SR, DICE.SR), baseline metrics (PSNR.LIN, SSIM.LIN, 2184 DICE.NN), and metadata. 2185 2186 Notes 2187 ----- 2188 When evaluating a 2-channel (segmentation) model, the primary metric for the 2189 segmentation task is the Dice score (`DICE.SR`). The intensity metrics (PSNR, SSIM) 2190 are still computed on the first channel. 
2191 """ 2192 padding=4 2193 mydf = pd.DataFrame() 2194 for k in range( len( model_filenames ) ): 2195 srmdl, upshape = read_srmodel( model_filenames[k] ) 2196 if verbose: 2197 print( model_filenames[k] ) 2198 print( upshape ) 2199 tarshape = [] 2200 inspc = ants.get_spacing(img) 2201 for j in range(len(img.shape)): 2202 tarshape.append( float(upshape[j]) * inspc[j] ) 2203 # uses linear interp 2204 dimg=ants.resample_image( img, tarshape, use_voxels=False, interp_type=0 ) 2205 dimg = ants.add_noise_to_image( dimg,'additivegaussian', [0,noise_sd] ) 2206 import math 2207 dicesr=math.nan 2208 dicenn=math.nan 2209 if upshape[3] == 2: 2210 seghigh = ants.threshold_image( img,"Otsu",n_classes) 2211 seglow = ants.resample_image( seghigh, tarshape, use_voxels=False, interp_type=1 ) 2212 dimgup=inference( dimg, srmdl, segmentation = seglow, poly_order=poly_order, verbose=verbose ) 2213 dimgupseg = dimgup['super_resolution_segmentation'] 2214 dimgup = dimgup['super_resolution'] 2215 segblock = ants.resample_image_to_target( seghigh, dimgupseg, interp_type='nearestNeighbor' ) 2216 segimgnn = ants.resample_image_to_target( seglow, dimgupseg, interp_type='nearestNeighbor' ) 2217 segblock[ dimgupseg == 0 ] = 0 2218 segimgnn[ dimgupseg == 0 ] = 0 2219 dicenn = ants.label_overlap_measures(segblock, segimgnn)['MeanOverlap'][0] 2220 dicesr = ants.label_overlap_measures(segblock, dimgupseg)['MeanOverlap'][0] 2221 else: 2222 dimgup=inference( dimg, srmdl, poly_order=poly_order, verbose=verbose ) 2223 dimglin = ants.resample_image_to_target( dimg, dimgup, interp_type='linear' ) 2224 imgblock = ants.resample_image_to_target( img, dimgup, interp_type='linear' ) 2225 dimgup[ imgblock == 0.0 ]=0.0 2226 dimglin[ imgblock == 0.0 ]=0.0 2227 padder = [] 2228 dimwarning=False 2229 for jj in range(img.dimension): 2230 padder.append( padding ) 2231 if img.shape[jj] != imgblock.shape[jj]: 2232 dimwarning=True 2233 if dimwarning: 2234 print("NOTE: dimensions of downsampled to upsampled image do not match!!!") 2235 print("we force them to match but this suggests results may not be reliable.") 2236 temp = os.path.basename( model_filenames[k] ) 2237 temp = re.sub( "siq_default_sisr_", "", temp ) 2238 temp = re.sub( "_best_mdl.keras", "", temp ) 2239 temp = re.sub( "_best_mdl.h5", "", temp ) 2240 if verbose and dimwarning: 2241 print( "original img shape" ) 2242 print( img.shape ) 2243 print( "resampled img shape" ) 2244 print( imgblock.shape ) 2245 a=[] 2246 imgshape = [] 2247 for aa in range(len(upshape)): 2248 a.append( str(upshape[aa]) ) 2249 if aa < len(imgblock.shape): 2250 imgshape.append( str( imgblock.shape[aa] ) ) 2251 if identifier is None: 2252 identifier=temp 2253 mydict = { 2254 "identifier":identifier, 2255 "imgshape":"x".join(imgshape), 2256 "mdl": temp, 2257 "mdlshape":"x".join(a), 2258 "PSNR.LIN": antspynet.psnr( imgblock, dimglin ), 2259 "PSNR.SR": antspynet.psnr( imgblock, dimgup ), 2260 "SSIM.LIN": antspynet.ssim( imgblock, dimglin ), 2261 "SSIM.SR": antspynet.ssim( imgblock, dimgup ), 2262 "DICE.NN": dicenn, 2263 "DICE.SR": dicesr, 2264 "dimwarning": dimwarning } 2265 if verbose: 2266 print( mydict ) 2267 temp = pd.DataFrame.from_records( [mydict], index=[0] ) 2268 mydf = pd.concat( [mydf,temp], axis=0 ) 2269 # end loop 2270 return mydf 2271 2272 2273 2274 2275def region_wise_super_resolution(image, mask, super_res_model, dilation_amount=4, verbose=False): 2276 """ 2277 Apply super-resolution model to each labeled region in the mask independently. 
2278 2279 Arguments 2280 --------- 2281 image : ANTsImage 2282 Input image. 2283 2284 mask : ANTsImage 2285 Integer-labeled segmentation mask with non-zero regions to upsample. 2286 2287 super_res_model : tf.keras.Model 2288 Trained super-resolution model. 2289 2290 dilation_amount : int 2291 Number of morphological dilations applied to each label region before cropping. 2292 2293 verbose : bool 2294 If True, print detailed status. 2295 2296 Returns 2297 ------- 2298 ANTsImage : Full-size super-resolved image with per-label inference and stitching. 2299 """ 2300 import ants 2301 import numpy as np 2302 from antspynet import apply_super_resolution_model_to_image 2303 2304 upFactor = [] 2305 input_shape = super_res_model.inputs[0].shape 2306 test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1] 2307 test_input = np.zeros(test_shape, dtype=np.float32) 2308 test_output = super_res_model(test_input) 2309 2310 for k in range(len(test_shape) - 2): # ignore batch + channel 2311 upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1])) 2312 2313 original_size = mask.shape # e.g., (x, y, z) 2314 new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor)) 2315 upsampled_mask = ants.resample_image(mask, new_size, use_voxels=True, interp_type=1) 2316 upsampled_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0) 2317 2318 unique_labels = list(np.unique(upsampled_mask.numpy())) 2319 if 0 in unique_labels: 2320 unique_labels.remove(0) 2321 2322 outimg = ants.image_clone(upsampled_image) 2323 2324 for lab in unique_labels: 2325 if verbose: 2326 print(f"Processing label: {lab}") 2327 regionmask = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount) 2328 cropped = ants.crop_image(image, regionmask) 2329 if cropped.shape[0] == 0: 2330 continue 2331 subimgsr = apply_super_resolution_model_to_image( 2332 cropped, super_res_model, target_range=[0, 1], verbose=verbose 2333 ) 2334 stitched = ants.decrop_image(subimgsr, outimg) 2335 outimg[upsampled_mask == lab] = stitched[upsampled_mask == lab] 2336 2337 return outimg 2338 2339 2340def region_wise_super_resolution_blended(image, mask, super_res_model, dilation_amount=4, verbose=False): 2341 """ 2342 Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts. 2343 2344 This version uses a weighted-averaging scheme based on distance transforms 2345 to create seamless transitions between super-resolved regions and the background. 2346 2347 Arguments 2348 --------- 2349 image : ANTsImage 2350 Input low-resolution image. 2351 2352 mask : ANTsImage 2353 Integer-labeled segmentation mask. 2354 2355 super_res_model : tf.keras.Model 2356 Trained super-resolution model. 2357 2358 dilation_amount : int 2359 Number of morphological dilations applied to each label region before cropping. 2360 This provides context to the SR model. 2361 2362 verbose : bool 2363 If True, print detailed status. 2364 2365 Returns 2366 ------- 2367 ANTsImage : Full-size, super-resolved image with seamless blending. 
2368 """ 2369 import ants 2370 import numpy as np 2371 from antspynet import apply_super_resolution_model_to_image 2372 epsilon32 = np.finfo(np.float32).eps 2373 normalize_weight_maps = True # Default behavior to normalize weight maps 2374 # --- Step 1: Determine upsampling factor and prepare initial images --- 2375 upFactor = [] 2376 input_shape = super_res_model.inputs[0].shape 2377 test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1] 2378 test_input = np.zeros(test_shape, dtype=np.float32) 2379 test_output = super_res_model(test_input) 2380 for k in range(len(test_shape) - 2): 2381 upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1])) 2382 2383 original_size = image.shape 2384 new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor)) 2385 2386 # The initial upsampled image will serve as our background 2387 background_sr_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0) 2388 2389 # --- Step 2: Initialize accumulator and weight sum canvases --- 2390 # These must be float type for accumulation 2391 accumulator = ants.image_clone(background_sr_image).astype('float32') * 0.0 2392 weight_sum = ants.image_clone(accumulator) 2393 2394 unique_labels = [l for l in np.unique(mask.numpy()) if l != 0] 2395 2396 for lab in unique_labels: 2397 if verbose: 2398 print(f"Blending label: {lab}") 2399 2400 # --- Step 3: Super-resolve a dilated patch (provides context to the model) --- 2401 region_mask_dilated = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount) 2402 cropped_lowres = ants.crop_image(image, region_mask_dilated) 2403 if cropped_lowres.shape[0] == 0: 2404 continue 2405 2406 # Apply the model to the cropped low-res patch 2407 sr_patch = apply_super_resolution_model_to_image( 2408 cropped_lowres, super_res_model, target_range=[0, 1] 2409 ) 2410 2411 # Place the super-resolved patch back onto a full-sized canvas 2412 sr_patch_full_size = ants.decrop_image(sr_patch, accumulator) 2413 2414 # --- Step 4: Create a smooth weight map for this region --- 2415 # We use the *non-dilated* mask for the weight map to ensure a sharp focus on the target region. 2416 region_mask_original = ants.threshold_image(mask, lab, lab) 2417 2418 # Resample the original region mask to the high-res grid 2419 weight_map = ants.resample_image(region_mask_original, new_size, use_voxels=True, interp_type=0) 2420 weight_map = ants.smooth_image(weight_map, sigma=2.0, 2421 sigma_in_physical_coordinates=False) 2422 if normalize_weight_maps: 2423 weight_map = ants.iMath(weight_map, "Normalize") 2424 # --- Step 5: Accumulate the weighted values and the weights themselves --- 2425 accumulator += sr_patch_full_size * weight_map 2426 weight_sum += weight_map 2427 2428 # --- Step 6: Final Combination --- 2429 # Normalize the accumulator by the total weight at each pixel 2430 weight_sum_np = weight_sum.numpy() 2431 accumulator_np = accumulator.numpy() 2432 2433 # Create a mask of pixels where blending occurred 2434 blended_mask = weight_sum_np > 0.0 # Use a small epsilon for float safety 2435 2436 # Start with the original upsampled image as the base 2437 final_image_np = background_sr_image.numpy() 2438 2439 # Perform the weighted average only where weights are non-zero 2440 final_image_np[blended_mask] = accumulator_np[blended_mask] / weight_sum_np[blended_mask] 2441 2442 # Re-insert any non-blended background regions that were processed 2443 # This handles cases where regions overlap; the weighted average takes care of it. 
2444 2445 return ants.from_numpy(final_image_np, origin=background_sr_image.origin, 2446 spacing=background_sr_image.spacing, direction=background_sr_image.direction) 2447 2448 2449def inference( 2450 image, 2451 mdl, 2452 truncation=None, 2453 segmentation=None, 2454 target_range=[1, 0], 2455 poly_order='hist', 2456 dilation_amount=0, 2457 verbose=False): 2458 """ 2459 Perform super-resolution inference on an input image, optionally guided by segmentation. 2460 2461 This function uses a trained deep learning model to enhance the resolution of a medical image. 2462 It optionally applies label-wise inference if a segmentation mask is provided. 2463 2464 Parameters 2465 ---------- 2466 image : ants.ANTsImage 2467 Input image to be super-resolved. 2468 2469 mdl : keras.Model 2470 Trained super-resolution model, typically from ANTsPyNet. 2471 2472 truncation : tuple or list of float, optional 2473 Percentile values (e.g., [0.01, 0.99]) for intensity truncation before model input. 2474 If None, no truncation is applied. 2475 2476 segmentation : ants.ANTsImage, optional 2477 A labeled segmentation mask. If provided, super-resolution is performed per label 2478 using `region_wise_super_resolution` or `super_resolution_segmentation_per_label`. 2479 2480 target_range : list of float 2481 Intensity range used for scaling the input before applying the model. 2482 Default is [1, 0] (internal default for `apply_super_resolution_model_to_image`). 2483 2484 poly_order : int, str or None 2485 Determines how to match intensity between the super-resolved image and the original. 2486 Options: 2487 - 'hist' : use histogram matching 2488 - int >= 1 : perform polynomial regression of this order 2489 - None : no intensity adjustment 2490 2491 dilation_amount : int 2492 Number of dilation steps applied to each segmentation label during 2493 region-based super-resolution (if segmentation is provided). 2494 2495 verbose : bool 2496 If True, print progress and status messages. 2497 2498 Returns 2499 ------- 2500 ANTsImage or dict 2501 - If `segmentation` is None, returns a single ANTsImage (super-resolved image). 2502 - If `segmentation` is provided, returns a dictionary with: 2503 - 'super_resolution': ANTsImage 2504 - other entries may include label-wise results or metadata. 
2505 2506 Examples 2507 -------- 2508 >>> import ants 2509 >>> import antspynet 2510 >>> from siq import inference 2511 >>> img = ants.image_read("lowres.nii.gz") 2512 >>> model = antspynet.get_pretrained_network("dbpn", target_suffix="T1") 2513 >>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True) 2514 2515 >>> seg = ants.image_read("mask.nii.gz") 2516 >>> sr_result = inference(img, model, segmentation=seg) 2517 >>> srimg = sr_result['super_resolution'] 2518 """ 2519 import ants 2520 import numpy as np 2521 import antspynet 2522 import antspyt1w 2523 from siq import region_wise_super_resolution 2524 2525 def apply_intensity_match(sr_image, reference_image, order, verbose=False): 2526 if order is None: 2527 return sr_image 2528 if verbose: 2529 print("Applying intensity match with", order) 2530 if order == 'hist': 2531 return ants.histogram_match_image(sr_image, reference_image) 2532 else: 2533 return ants.regression_match_image(sr_image, reference_image, poly_order=order) 2534 2535 pimg = ants.image_clone(image) 2536 if truncation is not None: 2537 pimg = ants.iMath(pimg, 'TruncateIntensity', truncation[0], truncation[1]) 2538 2539 input_shape = mdl.inputs[0].shape 2540 num_channels = int(input_shape[-1]) 2541 2542 if segmentation is not None: 2543 if num_channels == 1: 2544 if verbose: 2545 print("Using region-wise super resolution due to single-channel model with segmentation.") 2546 sr = region_wise_super_resolution_blended( 2547 pimg, segmentation, mdl, 2548 dilation_amount=dilation_amount, 2549 verbose=verbose 2550 ) 2551 ref = ants.resample_image_to_target(pimg, sr) 2552 return apply_intensity_match(sr, ref, poly_order, verbose) 2553 else: 2554 mynp = segmentation.numpy() 2555 mynp = list(np.unique(mynp)[1:len(mynp)].astype(int)) 2556 upFactor = [] 2557 if len(input_shape) == 5: 2558 testarr = np.zeros([1, 8, 8, 8, 2]) 2559 testarrout = mdl(testarr) 2560 for k in range(3): 2561 upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1])) 2562 elif len(input_shape) == 4: 2563 testarr = np.zeros([1, 8, 8, 2]) 2564 testarrout = mdl(testarr) 2565 for k in range(2): 2566 upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1])) 2567 temp = antspyt1w.super_resolution_segmentation_per_label( 2568 pimg, 2569 segmentation, 2570 upFactor, 2571 mdl, 2572 segmentation_numbers=mynp, 2573 target_range=target_range, 2574 dilation_amount=dilation_amount, 2575 poly_order=poly_order, 2576 max_lab_plus_one=True 2577 ) 2578 imgsr = temp['super_resolution'] 2579 ref = ants.resample_image_to_target(pimg, imgsr) 2580 return apply_intensity_match(imgsr, ref, poly_order, verbose) 2581 2582 # Default path: no segmentation 2583 imgsr = antspynet.apply_super_resolution_model_to_image( 2584 pimg, mdl, target_range=target_range, regression_order=None, verbose=verbose 2585 ) 2586 ref = ants.resample_image_to_target(pimg, imgsr) 2587 return apply_intensity_match(imgsr, ref, poly_order, verbose)
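A minimal end-to-end sketch of the inference path above follows; the model and image file names are hypothetical placeholders, and the package is assumed to be importable as siq.

import ants
import siq

# Hypothetical file names; substitute a trained model and a real image.
mdl, upfactor = siq.read_srmodel("siq_smallshort_train_2x2x2_1chan.keras")
img = ants.image_read("t1_lowres.nii.gz")

# Intensity-only super-resolution with intensity truncation and histogram matching
# back to the input's intensity distribution.
sr = siq.inference(img, mdl, truncation=[0.01, 0.99], poly_order='hist', verbose=True)
ants.image_write(sr, "t1_superres.nii.gz")

# Optional: benchmark one or more models against a high-resolution reference image.
# hires = ants.image_read("t1_hires.nii.gz")
# report = siq.compare_models(["siq_smallshort_train_2x2x2_1chan.keras"], hires)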
119def dbpn(input_image_size, 120 number_of_outputs=1, 121 number_of_base_filters=64, 122 number_of_feature_filters=256, 123 number_of_back_projection_stages=7, 124 convolution_kernel_size=(12, 12), 125 strides=(8, 8), 126 last_convolution=(3, 3), 127 number_of_loss_functions=1, 128 interpolation = 'nearest' 129 ): 130 """ 131 Creates a Deep Back-Projection Network (DBPN) for single image super-resolution. 132 133 This function constructs a Keras model based on the DBPN architecture, which 134 can be configured for either 2D or 3D inputs. The network uses iterative 135 up- and down-projection blocks to refine the high-resolution image estimate. A 136 key modification from the original paper is the option to use standard 137 interpolation for upsampling instead of deconvolution layers. 138 139 Reference: 140 - Haris, M., Shakhnarovich, G., & Ukita, N. (2018). Deep Back-Projection 141 Networks For Super-Resolution. In CVPR. 142 143 Parameters 144 ---------- 145 input_image_size : tuple or list 146 The shape of the input image, including the channel. 147 e.g., `(None, None, 1)` for 2D or `(None, None, None, 1)` for 3D. 148 149 number_of_outputs : int, optional 150 The number of channels in the output image. Default is 1. 151 152 number_of_base_filters : int, optional 153 The number of filters in the up/down projection blocks. Default is 64. 154 155 number_of_feature_filters : int, optional 156 The number of filters in the initial feature extraction layer. Default is 256. 157 158 number_of_back_projection_stages : int, optional 159 The number of iterative back-projection stages (T in the paper). Default is 7. 160 161 convolution_kernel_size : tuple or list, optional 162 The kernel size for the main projection convolutions. Should match the 163 dimensionality of the input. Default is (12, 12). 164 165 strides : tuple or list, optional 166 The strides for the up/down sampling operations, defining the 167 super-resolution factor. Default is (8, 8). 168 169 last_convolution : tuple or list, optional 170 The kernel size of the final reconstruction convolution. Default is (3, 3). 171 172 number_of_loss_functions : int, optional 173 If greater than 1, the model will have multiple identical output branches. 174 Typically set to 1. Default is 1. 175 176 interpolation : str, optional 177 The interpolation method to use for upsampling layers if not using 178 transposed convolutions. 'nearest' or 'bilinear'. Default is 'nearest'. 179 180 Returns 181 ------- 182 keras.Model 183 A Keras model implementing the DBPN architecture for the specified 184 parameters. 
185 """ 186 idim = len( input_image_size ) - 1 187 if idim == 2: 188 myconv = Conv2D 189 myconv_transpose = Conv2DTranspose 190 myupsampling = UpSampling2D 191 shax = ( 1, 2 ) 192 firstConv = (3,3) 193 firstStrides=(1,1) 194 smashConv=(1,1) 195 if idim == 3: 196 myconv = Conv3D 197 myconv_transpose = Conv3DTranspose 198 myupsampling = UpSampling3D 199 shax = ( 1, 2, 3 ) 200 firstConv = (3,3,3) 201 firstStrides=(1,1,1) 202 smashConv=(1,1,1) 203 def up_block_2d(L, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8), 204 include_dense_convolution_layer=True): 205 if include_dense_convolution_layer == True: 206 L = myconv(filters = number_of_filters, 207 use_bias=True, 208 kernel_size=smashConv, 209 strides=firstStrides, 210 padding='same')(L) 211 L = PReLU(alpha_initializer='zero', 212 shared_axes=shax)(L) 213 # Scale up 214 if idim == 2: 215 H0 = myupsampling( size = strides, interpolation=interpolation )(L) 216 if idim == 3: 217 H0 = myupsampling( size = strides )(L) 218 H0 = myconv(filters=number_of_filters, 219 kernel_size=firstConv, 220 strides=firstStrides, 221 use_bias=True, 222 padding='same')(H0) 223 H0 = PReLU(alpha_initializer='zero', 224 shared_axes=shax)(H0) 225 # Scale down 226 L0 = myconv(filters=number_of_filters, 227 kernel_size=kernel_size, 228 strides=strides, 229 kernel_initializer='glorot_uniform', 230 padding='same')(H0) 231 L0 = PReLU(alpha_initializer='zero', 232 shared_axes=shax)(L0) 233 # Residual 234 E = Subtract()([L0, L]) 235 # Scale residual up 236 if idim == 2: 237 H1 = myupsampling( size = strides, interpolation=interpolation )(E) 238 if idim == 3: 239 H1 = myupsampling( size = strides )(E) 240 H1 = myconv(filters=number_of_filters, 241 kernel_size=firstConv, 242 strides=firstStrides, 243 use_bias=True, 244 padding='same')(H1) 245 H1 = PReLU(alpha_initializer='zero', 246 shared_axes=shax)(H1) 247 # Output feature map 248 up_block = Add()([H0, H1]) 249 return up_block 250 def down_block_2d(H, number_of_filters=64, kernel_size=(12, 12), strides=(8, 8), 251 include_dense_convolution_layer=True): 252 if include_dense_convolution_layer == True: 253 H = myconv(filters = number_of_filters, 254 use_bias=True, 255 kernel_size=smashConv, 256 strides=firstStrides, 257 padding='same')(H) 258 H = PReLU(alpha_initializer='zero', 259 shared_axes=shax)(H) 260 # Scale down 261 L0 = myconv(filters=number_of_filters, 262 kernel_size=kernel_size, 263 strides=strides, 264 kernel_initializer='glorot_uniform', 265 padding='same')(H) 266 L0 = PReLU(alpha_initializer='zero', 267 shared_axes=shax)(L0) 268 # Scale up 269 if idim == 2: 270 H0 = myupsampling( size = strides, interpolation=interpolation )(L0) 271 if idim == 3: 272 H0 = myupsampling( size = strides )(L0) 273 H0 = myconv(filters=number_of_filters, 274 kernel_size=firstConv, 275 strides=firstStrides, 276 use_bias=True, 277 padding='same')(H0) 278 H0 = PReLU(alpha_initializer='zero', 279 shared_axes=shax)(H0) 280 # Residual 281 E = Subtract()([H0, H]) 282 # Scale residual down 283 L1 = myconv(filters=number_of_filters, 284 kernel_size=kernel_size, 285 strides=strides, 286 kernel_initializer='glorot_uniform', 287 padding='same')(E) 288 L1 = PReLU(alpha_initializer='zero', 289 shared_axes=shax)(L1) 290 # Output feature map 291 down_block = Add()([L0, L1]) 292 return down_block 293 inputs = Input(shape=input_image_size) 294 # Initial feature extraction 295 model = myconv(filters=number_of_feature_filters, 296 kernel_size=firstConv, 297 strides=firstStrides, 298 padding='same', 299 
kernel_initializer='glorot_uniform')(inputs) 300 model = PReLU(alpha_initializer='zero', 301 shared_axes=shax)(model) 302 # Feature smashing 303 model = myconv(filters=number_of_base_filters, 304 kernel_size=smashConv, 305 strides=firstStrides, 306 padding='same', 307 kernel_initializer='glorot_uniform')(model) 308 model = PReLU(alpha_initializer='zero', 309 shared_axes=shax)(model) 310 # Back projection 311 up_projection_blocks = [] 312 down_projection_blocks = [] 313 model = up_block_2d(model, number_of_filters=number_of_base_filters, 314 kernel_size=convolution_kernel_size, strides=strides) 315 up_projection_blocks.append(model) 316 for i in range(number_of_back_projection_stages): 317 if i == 0: 318 model = down_block_2d(model, number_of_filters=number_of_base_filters, 319 kernel_size=convolution_kernel_size, strides=strides) 320 down_projection_blocks.append(model) 321 model = up_block_2d(model, number_of_filters=number_of_base_filters, 322 kernel_size=convolution_kernel_size, strides=strides) 323 up_projection_blocks.append(model) 324 model = Concatenate()(up_projection_blocks) 325 else: 326 model = down_block_2d(model, number_of_filters=number_of_base_filters, 327 kernel_size=convolution_kernel_size, strides=strides, 328 include_dense_convolution_layer=True) 329 down_projection_blocks.append(model) 330 model = Concatenate()(down_projection_blocks) 331 model = up_block_2d(model, number_of_filters=number_of_base_filters, 332 kernel_size=convolution_kernel_size, strides=strides, 333 include_dense_convolution_layer=True) 334 up_projection_blocks.append(model) 335 model = Concatenate()(up_projection_blocks) 336 outputs = myconv(filters=number_of_outputs, 337 kernel_size=last_convolution, 338 strides=firstStrides, 339 padding = 'same', 340 kernel_initializer = "glorot_uniform")(model) 341 if number_of_loss_functions == 1: 342 deep_back_projection_network_model = Model(inputs=inputs, outputs=outputs) 343 else: 344 outputList=[] 345 for k in range(number_of_loss_functions): 346 outputList.append(outputs) 347 deep_back_projection_network_model = Model(inputs=inputs, outputs=outputList) 348 return deep_back_projection_network_model
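As a small construction sketch, the same factory builds the 3D variant when given 3D kernel and stride tuples; the values below are illustrative rather than recommended settings, and dbpn is assumed to be importable from the siq namespace.

import siq

# 2x upsampling per axis for a single-channel 3D input; kernel and stride tuples
# must match the input dimensionality.
mdl3d = siq.dbpn((None, None, None, 1),
                 number_of_outputs=1,
                 number_of_base_filters=64,
                 number_of_feature_filters=256,
                 number_of_back_projection_stages=7,
                 convolution_kernel_size=(6, 6, 6),
                 strides=(2, 2, 2),
                 last_convolution=(3, 3, 3),
                 interpolation='nearest')
mdl3d.summary()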
353def get_random_base_ind( full_dims, patchWidth, off=8 ): 354 """ 355 Generates a random top-left corner index for a patch. 356 357 This utility function computes a valid starting index (e.g., [x, y, z]) 358 for extracting a patch from a larger volume, ensuring the patch fits entirely 359 within the volume's boundaries, accounting for an offset. 360 361 Parameters 362 ---------- 363 full_dims : tuple or list 364 The dimensions of the full volume (e.g., img.shape). 365 366 patchWidth : tuple or list 367 The dimensions of the patch to be extracted. 368 369 off : int, optional 370 An offset from the edge of the volume to avoid sampling near borders. 371 Default is 8. 372 373 Returns 374 ------- 375 list 376 A list of integers representing the starting coordinates for the patch. 377 """ 378 baseInd = [None,None,None] 379 for k in range(3): 380 baseInd[k]=random.sample( range( off, full_dims[k]-1-patchWidth[k] ), 1 )[0] 381 return baseInd
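For example (values are arbitrary; the helper assumes a 3D volume):

import siq

full_dims = (64, 64, 64)
patch_width = (16, 16, 16)
# A random corner index such that corner + patch_width stays inside the volume,
# at least `off` voxels away from the lower border.
corner = siq.get_random_base_ind(full_dims=full_dims, patchWidth=patch_width, off=8)
print(corner)  # e.g. [23, 9, 40]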
385def get_random_patch( img, patchWidth ): 386 """ 387 Extracts a random patch from an image with non-zero variance. 388 389 This function repeatedly samples a random patch of a specified width from 390 the input image until it finds one where the standard deviation of pixel 391 intensities is greater than zero. This is useful for avoiding blank or 392 uniform patches during training data generation. 393 394 Parameters 395 ---------- 396 img : ants.ANTsImage 397 The source image from which to extract a patch. 398 399 patchWidth : tuple or list 400 The desired dimensions of the output patch. 401 402 Returns 403 ------- 404 ants.ANTsImage 405 A randomly extracted patch from the input image. 406 """ 407 mystd = 0 408 while mystd == 0: 409 inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 ) 410 hinds = [None,None,None] 411 for k in range(len(inds)): 412 hinds[k] = inds[k] + patchWidth[k] 413 myimg = ants.crop_indices( img, inds, hinds ) 414 mystd = myimg.std() 415 return myimg
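A quick sketch with a synthetic volume (purely illustrative):

import ants
import numpy as np
import siq

img = ants.from_numpy(np.random.normal(0, 1, size=(64, 64, 64)).astype('float32'))
patch = siq.get_random_patch(img, patchWidth=(16, 16, 16))
print(patch.shape)  # (16, 16, 16)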
417def get_random_patch_pair( img, img2, patchWidth ): 418 """ 419 Extracts a corresponding random patch from a pair of images. 420 421 This function finds a single random location and extracts a patch of the 422 same size and position from two different input images. It ensures that 423 both extracted patches have non-zero variance. This is useful for creating 424 paired training data (e.g., low-res and high-res images). 425 426 Parameters 427 ---------- 428 img : ants.ANTsImage 429 The first source image. 430 431 img2 : ants.ANTsImage 432 The second source image, spatially aligned with the first. 433 434 patchWidth : tuple or list 435 The desired dimensions of the output patches. 436 437 Returns 438 ------- 439 tuple of ants.ANTsImage 440 A tuple containing two corresponding patches: (patch_from_img, patch_from_img2). 441 """ 442 mystd = mystd2 = 0 443 ct = 0 444 while mystd == 0 or mystd2 == 0: 445 inds = get_random_base_ind( full_dims = img.shape, patchWidth=patchWidth, off=8 ) 446 hinds = [None,None,None] 447 for k in range(len(inds)): 448 hinds[k] = inds[k] + patchWidth[k] 449 myimg = ants.crop_indices( img, inds, hinds ) 450 myimg2 = ants.crop_indices( img2, inds, hinds ) 451 mystd = myimg.std() 452 mystd2 = myimg2.std() 453 ct = ct + 1 454 if ( ct > 20 ): 455 return myimg, myimg2 456 return myimg, myimg2
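A short sketch that pairs a volume with a smoothed copy of itself as a stand-in for an aligned high/low-resolution pair (illustrative only):

import ants
import numpy as np
import siq

hi = ants.from_numpy(np.random.normal(0, 1, size=(64, 64, 64)).astype('float32'))
lo = ants.smooth_image(hi, 1.0)  # degraded stand-in in the same physical space as hi
patch_hi, patch_lo = siq.get_random_patch_pair(hi, lo, patchWidth=(16, 16, 16))
print(patch_hi.shape, patch_lo.shape)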
458def pseudo_3d_vgg_features( inshape = [128,128,128], layer = 4, angle=0, pretrained=True, verbose=False ): 459 """ 460 Creates a pseudo-3D VGG feature extractor from a pre-trained 2D VGG model. 461 462 This function constructs a 3D VGG-style network and initializes its weights 463 by "stretching" the weights from a pre-trained 2D VGG19 model (trained on 464 ImageNet) along a specified axis. This is a technique to transfer 2D 465 perceptual knowledge to a 3D domain for tasks like perceptual loss. 466 467 Parameters 468 ---------- 469 inshape : list of int, optional 470 The input shape of the 3D volume, e.g., `[128, 128, 128]`. Default is `[128,128,128]`. 471 472 layer : int, optional 473 The block number of the VGG network from which to extract features. For 474 VGG19, this corresponds to block `layer` (e.g., layer=4 means 'block4_conv...'). 475 Default is 4. 476 477 angle : int, optional 478 The axis along which to project the 2D weights: 479 - 0: Axial plane (stretches along Z) 480 - 1: Coronal plane (stretches along Y) 481 - 2: Sagittal plane (stretches along X) 482 Default is 0. 483 484 pretrained : bool, optional 485 If True, loads the stretched ImageNet weights. If False, the model is 486 randomly initialized. Default is True. 487 488 verbose : bool, optional 489 If True, prints information about the layers being used. Default is False. 490 491 Returns 492 ------- 493 tf.keras.Model 494 A Keras model that takes a 3D volume as input and outputs the pseudo-3D 495 feature map from the specified layer and angle. 496 """ 497 def getLayerScaleFactorForTransferLearning( k, w3d, w2d ): 498 myfact = np.round( np.prod( w3d[k].shape ) / np.prod( w2d[k].shape) ) 499 return myfact 500 vgg19 = tf.keras.applications.VGG19( 501 include_top = False, weights = "imagenet", 502 input_shape = [inshape[0],inshape[1],3], 503 classes = 1000 ) 504 def findLayerIndex( layerName, mdl ): 505 for k in range( len( mdl.layers ) ): 506 if layerName == mdl.layers[k].name : 507 return k - 1 508 return None 509 layer_index = layer-1 # findLayerIndex( 'block2_conv2', vgg19 ) 510 vggmodelRaw = antspynet.create_vgg_model_3d( 511 [inshape[0],inshape[1],inshape[2],1], 512 number_of_classification_labels = 1000, 513 layers = [1, 2, 3, 4, 4], 514 lowest_resolution = 64, 515 convolution_kernel_size= (3, 3, 3), pool_size = (2, 2, 2), 516 strides = (2, 2, 2), number_of_dense_units= 4096, dropout_rate = 0, 517 style = 19, mode = "classification") 518 if verbose: 519 print( vggmodelRaw.layers[layer_index] ) 520 print( vggmodelRaw.layers[layer_index].name ) 521 print( vgg19.layers[layer_index] ) 522 print( vgg19.layers[layer_index].name ) 523 feature_extractor_2d = tf.keras.Model( 524 inputs = vgg19.input, 525 outputs = vgg19.layers[layer_index].output) 526 feature_extractor = tf.keras.Model( 527 inputs = vggmodelRaw.input, 528 outputs = vggmodelRaw.layers[layer_index].output) 529 wts_2d = feature_extractor_2d.weights 530 wts = feature_extractor.weights 531 def checkwtshape( a, b ): 532 if len(a.shape) != len(b.shape): 533 return False 534 for j in range(len(a.shape)): 535 if a.shape[j] != b.shape[j]: 536 return False 537 return True 538 for ww in range(len(wts)): 539 wts[ww]=wts[ww].numpy() 540 wts_2d[ww]=wts_2d[ww].numpy() 541 if checkwtshape( wts[ww], wts_2d[ww] ) and ww != 0: 542 wts[ww]=wts_2d[ww] 543 elif ww != 0: 544 # FIXME - should allow doing this across different angles 545 if angle == 0: 546 wts[ww][:,:,0,:,:]=wts_2d[ww]/3.0 547 wts[ww][:,:,1,:,:]=wts_2d[ww]/3.0 548 wts[ww][:,:,2,:,:]=wts_2d[ww]/3.0 549 if angle == 
1: 550 wts[ww][:,0,:,:,:]=wts_2d[ww]/3.0 551 wts[ww][:,1,:,:,:]=wts_2d[ww]/3.0 552 wts[ww][:,2,:,:,:]=wts_2d[ww]/3.0 553 if angle == 2: 554 wts[ww][0,:,:,:,:]=wts_2d[ww]/3.0 555 wts[ww][1,:,:,:,:]=wts_2d[ww]/3.0 556 wts[ww][2,:,:,:,:]=wts_2d[ww]/3.0 557 else: 558 wts[ww][:,:,:,0,:]=wts_2d[ww] 559 if pretrained: 560 feature_extractor.set_weights( wts ) 561 newinput = tf.keras.layers.Rescaling( 255.0, -127.5 )( feature_extractor.input ) 562 feature_extractor2 = feature_extractor( newinput ) 563 feature_extractor = tf.keras.Model( feature_extractor.input, feature_extractor2 ) 564 return feature_extractor
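A minimal usage sketch (not part of the original source): it assumes these helpers are exported at the top level of the siq package (as siq.get_data is), that Keras can download the VGG19 ImageNet weights, and uses an illustrative 64-voxel shape.

import numpy as np
import tensorflow as tf
import siq

# Axial (angle=0) pseudo-3D VGG feature extractor for 64^3 volumes.
feat = siq.pseudo_3d_vgg_features(inshape=[64, 64, 64], layer=4, angle=0, pretrained=True)

# Push a random volume through it; the output is the stretched block-4 feature map.
vol = np.random.rand(1, 64, 64, 64, 1).astype("float32")
print(feat(tf.constant(vol)).shape)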
566def pseudo_3d_vgg_features_unbiased( inshape = [128,128,128], layer = 4, verbose=False ): 567 """ 568 Create a pseudo-3D VGG-style feature extractor by aggregating axial, coronal, 569 and sagittal VGG feature representations. 570 571 This model extracts features along each principal axis using pre-trained 2D 572 VGG-style networks and concatenates them to form an unbiased pseudo-3D feature space. 573 574 Parameters 575 ---------- 576 inshape : list of int, optional 577 The input shape of the 3D volume, default is [128, 128, 128]. 578 579 layer : int, optional 580 The VGG feature layer to extract. Higher values correspond to deeper 581 layers in the pseudo-3D VGG backbone. 582 583 verbose : bool, optional 584 If True, prints debug messages during model construction. 585 586 Returns 587 ------- 588 tf.keras.Model 589 A TensorFlow Keras model that takes a 3D input volume and outputs the 590 concatenated pseudo-3D feature representation from the specified layer. 591 592 Notes 593 ----- 594 This is useful for perceptual loss or feature comparison in super-resolution 595 and image synthesis tasks. The same input is processed in three anatomical 596 planes (axial, coronal, sagittal), and features are concatenated. 597 598 See Also 599 -------- 600 pseudo_3d_vgg_features : Generates VGG features from a single anatomical plane. 601 """ 602 f = [ 603 pseudo_3d_vgg_features( inshape, layer, angle=0, pretrained=True, verbose=verbose ), 604 pseudo_3d_vgg_features( inshape, layer, angle=1, pretrained=True ), 605 pseudo_3d_vgg_features( inshape, layer, angle=2, pretrained=True ) ] 606 f1=f[0].inputs 607 f0o=f[0]( f1 ) 608 f1o=f[1]( f1 ) 609 f2o=f[2]( f1 ) 610 catter = tf.keras.layers.concatenate( [f0o, f1o, f2o ]) 611 feature_extractor = tf.keras.Model( f1, catter ) 612 return feature_extractor
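A corresponding sketch for the unbiased variant (same assumptions as above); it aggregates the three single-plane extractors, so construction cost is roughly tripled.

import siq

# Concatenates axial, coronal and sagittal pseudo-3D VGG features into one model.
feat3 = siq.pseudo_3d_vgg_features_unbiased(inshape=[64, 64, 64], layer=4)
feat3.summary()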
def get_grader_feature_network( layer=6 ):
    """
    Load and extract a ResNet-based feature subnetwork for perceptual loss or quality grading.

    This function loads a pre-trained 3D ResNet model ("grader") used for
    perceptual feature extraction and returns a subnetwork that outputs activations
    from a specified internal layer.

    Parameters
    ----------
    layer : int, optional
        The index of the internal ResNet layer whose output should be used as
        the feature representation. Default is layer 6.

    Returns
    -------
    tf.keras.Model
        A Keras model that outputs features from the specified layer of the
        pre-trained 3D ResNet grader model.

    Raises
    ------
    Exception
        If the pre-trained weights file (`resnet_grader.h5`) is not found.

    Notes
    -----
    The pre-trained weights should be located in: `~/.antspyt1w/resnet_grader.h5`

    This model is typically used to compute perceptual loss by comparing
    intermediate activations between target and prediction volumes.

    See Also
    --------
    antspynet.create_resnet_model_3d : Constructs the base ResNet model.
    """
    grader = antspynet.create_resnet_model_3d(
        [None,None,None,1],
        lowest_resolution = 32,
        number_of_outputs = 4,
        cardinality = 1,
        squeeze_and_excite = False )
    # the folder and data below are available from antspyt1w get_data
    graderfn = os.path.expanduser( "~/.antspyt1w/resnet_grader.h5" )
    if not exists( graderfn ):
        raise Exception("graderfn " + graderfn + " does not exist")
    grader.load_weights( graderfn )
    # feature_extractor_23 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[23].output )
    # feature_extractor_44 = tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[44].output )
    return tf.keras.Model( inputs=grader.inputs, outputs=grader.layers[layer].output )
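A short usage sketch (assumes the grader weights already exist at ~/.antspyt1w/resnet_grader.h5, e.g. after fetching the antspyt1w data):

import siq

# Feature sub-network used as a perceptual-loss backbone.
feat = siq.get_grader_feature_network(layer=6)
feat.summary()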
666def default_dbpn( 667 strider, # length should equal dimensionality 668 dimensionality = 3, 669 nfilt=64, 670 nff = 256, 671 convn = 6, 672 lastconv = 3, 673 nbp=7, 674 nChannelsIn=1, 675 nChannelsOut=1, 676 option = None, 677 intensity_model=None, 678 segmentation_model=None, 679 sigmoid_second_channel=False, 680 clone_intensity_to_segmentation=False, 681 pro_seg = 0, 682 freeze = False, 683 verbose=False 684 ): 685 """ 686 Constructs a DBPN model based on input parameters, and can optionally 687 use external models for intensity or segmentation processing. 688 689 Args: 690 strider (list): List of strides, length must match `dimensionality`. 691 dimensionality (int): Number of dimensions (2 or 3). Default is 3. 692 nfilt (int): Number of base filters. Default is 64. 693 nff (int): Number of feature filters. Default is 256. 694 convn (int): Convolution kernel size. Default is 6. 695 lastconv (int): Size of the last convolution. Default is 3. 696 nbp (int): Number of back projection stages. Default is 7. 697 nChannelsIn (int): Number of input channels. Default is 1. 698 nChannelsOut (int): Number of output channels. Default is 1. 699 option (str): Model size option ('tiny', 'small', 'medium', 'large'). Default is None. 700 intensity_model (tf.keras.Model): Optional external intensity model. 701 segmentation_model (tf.keras.Model): Optional external segmentation model. 702 sigmoid_second_channel (bool): If True, applies sigmoid to second channel in output. 703 clone_intensity_to_segmentation (bool): If True, clones intensity model weights to segmentation. 704 pro_seg (int): If greater than 0, adds a segmentation arm. 705 freeze (bool): If True, freezes the layers of the intensity/segmentation models. 706 verbose (bool): If True, prints detailed logs. 707 708 Returns: 709 Model: A Keras model based on the specified configuration. 710 711 Raises: 712 Exception: If `len(strider)` is not equal to `dimensionality`. 713 """ 714 if option == 'tiny': 715 nfilt=32 716 nff = 64 717 convn = 3 718 lastconv = 1 719 nbp=2 720 elif option == 'small': 721 nfilt=32 722 nff = 64 723 convn = 6 724 lastconv = 3 725 nbp=4 726 elif option == 'medium': 727 nfilt=64 728 nff = 128 729 convn = 6 730 lastconv = 3 731 nbp=4 732 else: 733 option='large' 734 if verbose: 735 print("Building mode of size: " + option) 736 if intensity_model is not None: 737 print("user-passed intensity model will be frozen - only segmentation will train") 738 if segmentation_model is not None: 739 print("user-passed segmentation model will be frozen - only intensity will train") 740 741 if len(strider) != dimensionality: 742 raise Exception("len(strider) != dimensionality") 743 # **model instantiation**: these are close to defaults for the 2x network.<br> 744 # empirical evidence suggests that making covolutions and strides evenly<br> 745 # divisible by each other reduces artifacts. 2*3=6. 
746 # ofn='./models/dsr3d_'+str(strider)+'up_' + str(nfilt) + '_' + str( nff ) + '_' + str(convn)+ '_' + str(lastconv)+ '_' + str(os.environ['CUDA_VISIBLE_DEVICES'])+'_v0.0.keras' 747 if dimensionality == 2 : 748 mdl = dbpn( (None,None,nChannelsIn), 749 number_of_outputs=nChannelsOut, 750 number_of_base_filters=nfilt, 751 number_of_feature_filters=nff, 752 number_of_back_projection_stages=nbp, 753 convolution_kernel_size=(convn, convn), 754 strides=(strider[0], strider[1]), 755 last_convolution=(lastconv, lastconv), 756 number_of_loss_functions=1, 757 interpolation='nearest') 758 if dimensionality == 3 : 759 mdl = dbpn( (None,None,None,nChannelsIn), 760 number_of_outputs=nChannelsOut, 761 number_of_base_filters=nfilt, 762 number_of_feature_filters=nff, 763 number_of_back_projection_stages=nbp, 764 convolution_kernel_size=(convn, convn, convn), 765 strides=(strider[0], strider[1], strider[2]), 766 last_convolution=(lastconv, lastconv, lastconv), number_of_loss_functions=1, interpolation='nearest') 767 if sigmoid_second_channel and pro_seg != 0 : 768 if dimensionality == 2 : 769 input_image_size = (None,None,2) 770 if intensity_model is None: 771 intensity_model = dbpn( (None,None,1), 772 number_of_outputs=1, 773 number_of_base_filters=nfilt, 774 number_of_feature_filters=nff, 775 number_of_back_projection_stages=nbp, 776 convolution_kernel_size=(convn, convn), 777 strides=(strider[0], strider[1]), 778 last_convolution=(lastconv, lastconv), 779 number_of_loss_functions=1, 780 interpolation='nearest') 781 else: 782 if freeze: 783 for layer in intensity_model.layers: 784 layer.trainable = False 785 if segmentation_model is None: 786 segmentation_model = dbpn( (None,None,1), 787 number_of_outputs=1, 788 number_of_base_filters=nfilt, 789 number_of_feature_filters=nff, 790 number_of_back_projection_stages=nbp, 791 convolution_kernel_size=(convn, convn), 792 strides=(strider[0], strider[1]), 793 last_convolution=(lastconv, lastconv), 794 number_of_loss_functions=1, interpolation='linear') 795 else: 796 if freeze: 797 for layer in segmentation_model.layers: 798 layer.trainable = False 799 if dimensionality == 3 : 800 input_image_size = (None,None,None,2) 801 if intensity_model is None: 802 intensity_model = dbpn( (None,None,None,1), 803 number_of_outputs=1, 804 number_of_base_filters=nfilt, 805 number_of_feature_filters=nff, 806 number_of_back_projection_stages=nbp, 807 convolution_kernel_size=(convn, convn, convn), 808 strides=(strider[0], strider[1], strider[2]), 809 last_convolution=(lastconv, lastconv, lastconv), 810 number_of_loss_functions=1, interpolation='nearest') 811 else: 812 if freeze: 813 for layer in intensity_model.layers: 814 layer.trainable = False 815 if segmentation_model is None: 816 segmentation_model = dbpn( (None,None,None,1), 817 number_of_outputs=1, 818 number_of_base_filters=nfilt, 819 number_of_feature_filters=nff, 820 number_of_back_projection_stages=nbp, 821 convolution_kernel_size=(convn, convn, convn), 822 strides=(strider[0], strider[1], strider[2]), 823 last_convolution=(lastconv, lastconv, lastconv), 824 number_of_loss_functions=1, interpolation='linear') 825 else: 826 if freeze: 827 for layer in segmentation_model.layers: 828 layer.trainable = False 829 if verbose: 830 print( "len intensity_model layers : " + str( len( intensity_model.layers ))) 831 print( "len intensity_model weights : " + str( len( intensity_model.weights ))) 832 print( "len segmentation_model layers : " + str( len( segmentation_model.layers ))) 833 print( "len segmentation_model weights : " + 
str( len( segmentation_model.weights ))) 834 if clone_intensity_to_segmentation: 835 for k in range(len( segmentation_model.weights )): 836 if k < len( intensity_model.weights ): 837 if intensity_model.weights[k].shape == segmentation_model.weights[k].shape: 838 segmentation_model.weights[k] = intensity_model.weights[k] 839 inputs = tf.keras.Input(shape=input_image_size) 840 insplit = tf.split( inputs, 2, dimensionality+1) 841 outputs = [ 842 intensity_model( insplit[0] ), 843 tf.nn.sigmoid( segmentation_model( insplit[1] ) ) ] 844 mdlout = tf.concat( outputs, axis=dimensionality+1 ) 845 return Model(inputs=inputs, outputs=mdlout ) 846 if pro_seg > 0 and intensity_model is not None: 847 if verbose and freeze: 848 print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid") 849 if verbose and not freeze: 850 print("Add a segmentation arm to the end. freeze intensity. intensity_model(seg) => conv => sigmoid") 851 if freeze: 852 for layer in intensity_model.layers: 853 layer.trainable = False 854 if dimensionality == 2 : 855 input_image_size = (None,None,2) 856 elif dimensionality == 3 : 857 input_image_size = (None, None,None,2) 858 if dimensionality == 2: 859 myconv = Conv2D 860 firstConv = (convn,convn) 861 firstStrides=(1,1) 862 smashConv=(pro_seg,pro_seg) 863 if dimensionality == 3: 864 myconv = Conv3D 865 firstConv = (convn,convn,convn) 866 firstStrides=(1,1,1) 867 smashConv=(pro_seg,pro_seg,pro_seg) 868 inputs = tf.keras.Input(shape=input_image_size) 869 insplit = tf.split( inputs, 2, dimensionality+1) 870 # define segmentation arm 871 seggit = intensity_model( insplit[1] ) 872 L0 = myconv(filters=nff, 873 kernel_size=firstConv, 874 strides=firstStrides, 875 kernel_initializer='glorot_uniform', 876 padding='same')(seggit) 877 L1 = myconv(filters=nff, 878 kernel_size=firstConv, 879 strides=firstStrides, 880 kernel_initializer='glorot_uniform', 881 padding='same')(L0) 882 L2 = myconv(filters=1, 883 kernel_size=smashConv, 884 strides=firstStrides, 885 kernel_initializer='glorot_uniform', 886 padding='same')(L1) 887 outputs = [ 888 intensity_model( insplit[0] ), 889 tf.nn.sigmoid( L2 ) ] 890 mdlout = tf.concat( outputs, axis=dimensionality+1 ) 891 return Model(inputs=inputs, outputs=mdlout ) 892 return mdl
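A construction sketch with hypothetical parameter choices: a 2x-per-axis 3D model using the 'small' preset, plus an anisotropic variant.

import siq

# 2x super-resolution in each axis; 'option' presets override the filter counts.
mdl = siq.default_dbpn([2, 2, 2], dimensionality=3, option='small', verbose=True)

# Anisotropic example: 4x in-plane, 2x through-plane.
mdl_aniso = siq.default_dbpn([4, 4, 2], dimensionality=3, option='small')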
894def image_patch_training_data_from_filenames( 895 filenames, 896 target_patch_size, 897 target_patch_size_low, 898 nPatches = 128, 899 istest = False, 900 patch_scaler=True, 901 to_tensorflow = False, 902 verbose = False 903 ): 904 """ 905 Generates a batch of paired high- and low-resolution image patches for training. 906 907 This function creates training data by taking a list of high-resolution source 908 images, extracting random patches, and then downsampling them to create 909 low-resolution counterparts. This provides the (input, ground_truth) pairs 910 needed to train a super-resolution model. 911 912 Parameters 913 ---------- 914 filenames : list of str 915 A list of file paths to the high-resolution source images. 916 917 target_patch_size : tuple or list of int 918 The dimensions of the high-resolution (ground truth) patch to extract, 919 e.g., `(128, 128, 128)`. 920 921 target_patch_size_low : tuple or list of int 922 The dimensions of the low-resolution (input) patch to generate. The ratio 923 between `target_patch_size` and `target_patch_size_low` determines the 924 super-resolution factor. 925 926 nPatches : int, optional 927 The number of patch pairs to generate in this batch. Default is 128. 928 929 istest : bool, optional 930 If True, the function also generates a third output array containing patches 931 that have been naively upsampled using linear interpolation. This is useful 932 for calculating baseline evaluation metrics (e.g., PSNR) against which the 933 model's performance can be compared. Default is False. 934 935 patch_scaler : bool, optional 936 If True, scales the intensity of each high-resolution patch to the [0, 1] 937 range before creating the downsampled version. This can help with 938 training stability. Default is True. 939 940 to_tensorflow : bool, optional 941 If True, casts the output NumPy arrays to TensorFlow tensors. Default is False. 942 943 verbose : bool, optional 944 If True, prints progress messages during patch generation. Default is False. 945 946 Returns 947 ------- 948 tuple 949 A tuple of NumPy arrays or TensorFlow tensors. 950 - If `istest` is False: `(patchesResam, patchesOrig)` 951 - `patchesResam`: The batch of low-resolution input patches (X_train). 952 - `patchesOrig`: The batch of high-resolution ground truth patches (y_train). 953 - If `istest` is True: `(patchesResam, patchesOrig, patchesUp)` 954 - `patchesUp`: The batch of baseline, linearly-upsampled patches. 
955 """ 956 if verbose: 957 print("begin image_patch_training_data_from_filenames") 958 tardim = len( target_patch_size ) 959 strider = [] 960 for j in range( tardim ): 961 strider.append( np.round( target_patch_size[j]/target_patch_size_low[j]) ) 962 if tardim == 3: 963 shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],target_patch_size[2],1) 964 shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],target_patch_size_low[2],1) 965 if tardim == 2: 966 shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],1) 967 shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],1) 968 patchesOrig = np.zeros(shape=shaperhi) 969 patchesResam = np.zeros(shape=shaperlo) 970 patchesUp = None 971 if istest: 972 patchesUp = np.zeros(shape=patchesOrig.shape) 973 for myn in range(nPatches): 974 if verbose: 975 print(myn) 976 imgfn = random.sample( filenames, 1 )[0] 977 if verbose: 978 print(imgfn) 979 img = ants.image_read( imgfn ).iMath("Normalize") 980 if img.components > 1: 981 img = ants.split_channels(img)[0] 982 img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) ) 983 ants.set_origin( img, ants.get_center_of_mass(img) ) 984 img = ants.iMath(img,"Normalize") 985 spc = ants.get_spacing( img ) 986 newspc = [] 987 for jj in range(len(spc)): 988 newspc.append(spc[jj]*strider[jj]) 989 interp_type = random.choice( [0,1] ) 990 if True: 991 imgp = get_random_patch( img, target_patch_size ) 992 imgpmin = imgp.min() 993 if patch_scaler: 994 imgp = imgp - imgpmin 995 imgpmax = imgp.max() 996 if imgpmax > 0 : 997 imgp = imgp / imgpmax 998 rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type ) 999 if istest: 1000 rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0 ) 1001 if tardim == 3: 1002 patchesOrig[myn,:,:,:,0] = imgp.numpy() 1003 patchesResam[myn,:,:,:,0] = rimgp.numpy() 1004 if istest: 1005 patchesUp[myn,:,:,:,0] = rimgbi.numpy() 1006 if tardim == 2: 1007 patchesOrig[myn,:,:,0] = imgp.numpy() 1008 patchesResam[myn,:,:,0] = rimgp.numpy() 1009 if istest: 1010 patchesUp[myn,:,:,0] = rimgbi.numpy() 1011 if to_tensorflow: 1012 patchesOrig = tf.cast( patchesOrig, "float32") 1013 patchesResam = tf.cast( patchesResam, "float32") 1014 if istest: 1015 if to_tensorflow: 1016 patchesUp = tf.cast( patchesUp, "float32") 1017 return patchesResam, patchesOrig, patchesUp
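A small batch-generation sketch (the file paths are hypothetical placeholders for real high-resolution NIfTI images):

import siq

filenames = ['/path/to/t1_a.nii.gz', '/path/to/t1_b.nii.gz']  # hypothetical
X, y, up = siq.image_patch_training_data_from_filenames(
    filenames,
    target_patch_size=(64, 64, 64),
    target_patch_size_low=(32, 32, 32),   # implies a 2x super-resolution factor
    nPatches=4,
    istest=True,
    to_tensorflow=True)
# X: (4, 32, 32, 32, 1), y: (4, 64, 64, 64, 1), up: (4, 64, 64, 64, 1)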
1020def seg_patch_training_data_from_filenames( 1021 filenames, 1022 target_patch_size, 1023 target_patch_size_low, 1024 nPatches = 128, 1025 istest = False, 1026 patch_scaler=True, 1027 to_tensorflow = False, 1028 verbose = False 1029 ): 1030 """ 1031 Generates a batch of paired training data containing both images and segmentations. 1032 1033 This function extends `image_patch_training_data_from_filenames` by adding a 1034 second channel to the data. For each extracted image patch, it also generates 1035 a corresponding segmentation mask using Otsu's thresholding. This is useful for 1036 training multi-task models that perform super-resolution on both an image and 1037 its associated segmentation simultaneously. 1038 1039 Parameters 1040 ---------- 1041 filenames : list of str 1042 A list of file paths to the high-resolution source images. 1043 1044 target_patch_size : tuple or list of int 1045 The dimensions of the high-resolution patch, e.g., `(128, 128, 128)`. 1046 1047 target_patch_size_low : tuple or list of int 1048 The dimensions of the low-resolution input patch. 1049 1050 nPatches : int, optional 1051 The number of patch pairs to generate. Default is 128. 1052 1053 istest : bool, optional 1054 If True, also generates a third output array containing baseline upsampled 1055 intensity images (channel 0 only). Default is False. 1056 1057 patch_scaler : bool, optional 1058 If True, scales the intensity of each image patch to the [0, 1] range. 1059 Default is True. 1060 1061 to_tensorflow : bool, optional 1062 If True, casts the output NumPy arrays to TensorFlow tensors. Default is False. 1063 1064 verbose : bool, optional 1065 If True, prints progress messages. Default is False. 1066 1067 Returns 1068 ------- 1069 tuple 1070 A tuple of multi-channel NumPy arrays or TensorFlow tensors. The structure 1071 is the same as `image_patch_training_data_from_filenames`, but each 1072 array has a channel dimension of 2: 1073 - Channel 0: The intensity image. 1074 - Channel 1: The binary segmentation mask. 
1075 """ 1076 if verbose: 1077 print("begin seg_patch_training_data_from_filenames") 1078 tardim = len( target_patch_size ) 1079 strider = [] 1080 nchan = 2 1081 for j in range( tardim ): 1082 strider.append( np.round( target_patch_size[j]/target_patch_size_low[j]) ) 1083 if tardim == 3: 1084 shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],target_patch_size[2],nchan) 1085 shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],target_patch_size_low[2],nchan) 1086 if tardim == 2: 1087 shaperhi = (nPatches,target_patch_size[0],target_patch_size[1],nchan) 1088 shaperlo = (nPatches,target_patch_size_low[0],target_patch_size_low[1],nchan) 1089 patchesOrig = np.zeros(shape=shaperhi) 1090 patchesResam = np.zeros(shape=shaperlo) 1091 patchesUp = None 1092 if istest: 1093 patchesUp = np.zeros(shape=patchesOrig.shape) 1094 for myn in range(nPatches): 1095 if verbose: 1096 print(myn) 1097 imgfn = random.sample( filenames, 1 )[0] 1098 if verbose: 1099 print(imgfn) 1100 img = ants.image_read( imgfn ).iMath("Normalize") 1101 if img.components > 1: 1102 img = ants.split_channels(img)[0] 1103 img = ants.crop_image( img, ants.threshold_image( img, 0.05, 1 ) ) 1104 ants.set_origin( img, ants.get_center_of_mass(img) ) 1105 img = ants.iMath(img,"Normalize") 1106 spc = ants.get_spacing( img ) 1107 newspc = [] 1108 for jj in range(len(spc)): 1109 newspc.append(spc[jj]*strider[jj]) 1110 interp_type = random.choice( [0,1] ) 1111 seg_class = random.choice( [1,2] ) 1112 if True: 1113 imgp = get_random_patch( img, target_patch_size ) 1114 imgpmin = imgp.min() 1115 if patch_scaler: 1116 imgp = imgp - imgpmin 1117 imgpmax = imgp.max() 1118 if imgpmax > 0 : 1119 imgp = imgp / imgpmax 1120 segp = ants.threshold_image( imgp, "Otsu", 2 ).threshold_image( seg_class, seg_class ) 1121 rimgp = ants.resample_image( imgp, newspc, use_voxels = False, interp_type=interp_type ) 1122 rsegp = ants.resample_image( segp, newspc, use_voxels = False, interp_type=interp_type ) 1123 if istest: 1124 rimgbi = ants.resample_image( rimgp, spc, use_voxels = False, interp_type=0 ) 1125 if tardim == 3: 1126 patchesOrig[myn,:,:,:,0] = imgp.numpy() 1127 patchesResam[myn,:,:,:,0] = rimgp.numpy() 1128 patchesOrig[myn,:,:,:,1] = segp.numpy() 1129 patchesResam[myn,:,:,:,1] = rsegp.numpy() 1130 if istest: 1131 patchesUp[myn,:,:,:,0] = rimgbi.numpy() 1132 if tardim == 2: 1133 patchesOrig[myn,:,:,0] = imgp.numpy() 1134 patchesResam[myn,:,:,0] = rimgp.numpy() 1135 patchesOrig[myn,:,:,1] = segp.numpy() 1136 patchesResam[myn,:,:,1] = rsegp.numpy() 1137 if istest: 1138 patchesUp[myn,:,:,0] = rimgbi.numpy() 1139 if to_tensorflow: 1140 patchesOrig = tf.cast( patchesOrig, "float32") 1141 patchesResam = tf.cast( patchesResam, "float32") 1142 if istest: 1143 if to_tensorflow: 1144 patchesUp = tf.cast( patchesUp, "float32") 1145 return patchesResam, patchesOrig, patchesUp
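The segmentation-aware variant follows the same pattern; the extra channel holds the Otsu-derived mask (same hypothetical file list as above):

import siq

filenames = ['/path/to/t1_a.nii.gz', '/path/to/t1_b.nii.gz']  # hypothetical
X2, y2, _ = siq.seg_patch_training_data_from_filenames(
    filenames,
    target_patch_size=(64, 64, 64),
    target_patch_size_low=(32, 32, 32),
    nPatches=4,
    to_tensorflow=True)
# X2: (4, 32, 32, 32, 2); channel 0 = intensity, channel 1 = segmentation mask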
def read( filename ):
    """
    Reads an image or a NumPy array from a file.

    This function acts as a wrapper to intelligently load data. It checks the
    file extension to decide whether to use `ants.image_read` for standard
    medical image formats (e.g., .nii.gz, .mha) or `numpy.load` for `.npy` files.

    Parameters
    ----------
    filename : str
        The full path to the file to be read.

    Returns
    -------
    ants.ANTsImage or np.ndarray
        The loaded data object, either as an ANTsImage or a NumPy array.
    """
    import re
    isnpy = len( re.sub( ".npy", "", filename ) ) != len( filename )
    if not isnpy:
        myoutput = ants.image_read( filename )
    else:
        myoutput = np.load( filename )
    return myoutput
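Usage sketch (hypothetical paths):

import siq

img = siq.read('/path/to/image.nii.gz')   # -> ants.ANTsImage
arr = siq.read('/path/to/patches.npy')    # -> np.ndarray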
1174def auto_weight_loss( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, verbose=True ): 1175 """ 1176 Automatically compute weighting coefficients for a combined loss function 1177 based on intensity (MSE), perceptual similarity (feature), and total variation (TV). 1178 1179 Parameters 1180 ---------- 1181 mdl : tf.keras.Model 1182 A trained or untrained model to evaluate predictions on input `x`. 1183 1184 feature_extractor : tf.keras.Model 1185 A model that extracts intermediate features from the input. Commonly a VGG or ResNet 1186 trained on a perceptual task. 1187 1188 x : tf.Tensor 1189 Input batch to the model. 1190 1191 y : tf.Tensor 1192 Ground truth target for `x`, typically a batch of 2D or 3D volumes. 1193 1194 feature : float, optional 1195 Weighting factor for the feature (perceptual) term in the loss. Default is 2.0. 1196 1197 tv : float, optional 1198 Weighting factor for the total variation term in the loss. Default is 0.1. 1199 1200 verbose : bool, optional 1201 If True, prints each component of the loss and its scaled value. 1202 1203 Returns 1204 ------- 1205 list of float 1206 A list of computed weights in the order: 1207 `[msq_weight, feature_weight, tv_weight]` 1208 1209 Notes 1210 ----- 1211 The total loss (to be used during training) can then be constructed as: 1212 1213 `L = msq_weight * MSE + feature_weight * perceptual_loss + tv_weight * TV` 1214 1215 This function is typically used to balance loss terms before training. 1216 """ 1217 y_pred = mdl( x ) 1218 squared_difference = tf.square( y - y_pred) 1219 if len( y.shape ) == 5: 1220 tdim = 3 1221 myax = [1,2,3,4] 1222 if len( y.shape ) == 4: 1223 tdim = 2 1224 myax = [1,2,3] 1225 msqTerm = tf.reduce_mean(squared_difference, axis=myax) 1226 temp1 = feature_extractor(y) 1227 temp2 = feature_extractor(y_pred) 1228 feature_difference = tf.square(temp1-temp2) 1229 featureTerm = tf.reduce_mean(feature_difference, axis=myax) 1230 msqw = 10.0 1231 featw = feature * msqw * msqTerm / featureTerm 1232 mytv = tf.cast( 0.0, 'float32') 1233 if tdim == 3: 1234 for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version 1235 sqzd = y_pred[k,:,:,:,:] 1236 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) 1237 if tdim == 2: 1238 mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) ) 1239 tvw = tv * msqw * msqTerm / mytv 1240 if verbose : 1241 print( "MSQ: " + str( msqw * msqTerm ) ) 1242 print( "Feat: " + str( featw * featureTerm ) ) 1243 print( "Tv: " + str( mytv * tvw ) ) 1244 wts = [msqw,featw.numpy().mean(),tvw.numpy().mean()] 1245 return wts
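A loss-balancing sketch under the assumptions used above (hypothetical file list, grader weights on disk): generate a small test batch, then compute the weights before training.

import siq

filenames = ['/path/to/t1_a.nii.gz']                        # hypothetical
X, y, _ = siq.image_patch_training_data_from_filenames(
    filenames, (64, 64, 64), (32, 32, 32), nPatches=2, to_tensorflow=True)
mdl = siq.default_dbpn([2, 2, 2], option='small')           # untrained 2x model
feat = siq.get_grader_feature_network(layer=6)              # perceptual backbone
wts = siq.auto_weight_loss(mdl, feat, X, y, feature=2.0, tv=0.1)
# wts == [msq_weight, feature_weight, tv_weight]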
1247def auto_weight_loss_seg( mdl, feature_extractor, x, y, feature=2.0, tv=0.1, dice=0.5, verbose=True ): 1248 """ 1249 Automatically compute weighting coefficients for a combined loss function 1250 that includes MSE, perceptual similarity, total variation, and segmentation Dice loss. 1251 1252 Parameters 1253 ---------- 1254 mdl : tf.keras.Model 1255 A segmentation + super-resolution model that outputs both image and label predictions. 1256 1257 feature_extractor : tf.keras.Model 1258 Feature extractor model used to compute perceptual similarity loss. 1259 1260 x : tf.Tensor 1261 Input tensor to the model. 1262 1263 y : tf.Tensor 1264 Target tensor with two channels: [intensity_image, segmentation_label]. 1265 1266 feature : float, optional 1267 Relative weight of the perceptual feature loss term. Default is 2.0. 1268 1269 tv : float, optional 1270 Relative weight of the total variation (TV) term. Default is 0.1. 1271 1272 dice : float, optional 1273 Relative weight of the Dice loss term (for segmentation agreement). Default is 0.5. 1274 1275 verbose : bool, optional 1276 If True, prints the scaled values of each component loss. 1277 1278 Returns 1279 ------- 1280 list of float 1281 A list of loss term weights in the order: 1282 `[msq_weight, feature_weight, tv_weight, dice_weight]` 1283 1284 Notes 1285 ----- 1286 - The input and output tensors must be shaped such that the last axis is 2: 1287 channel 0 is intensity, channel 1 is segmentation. 1288 - This is useful for dual-task networks that predict both high-res images 1289 and associated segmentation masks. 1290 1291 See Also 1292 -------- 1293 binary_dice_loss : Computes Dice loss between predicted and ground-truth masks. 1294 """ 1295 y_pred = mdl( x ) 1296 if len( y.shape ) == 5: 1297 tdim = 3 1298 myax = [1,2,3,4] 1299 if len( y.shape ) == 4: 1300 tdim = 2 1301 myax = [1,2,3] 1302 y_intensity = tf.split( y, 2, axis=tdim+1 )[0] 1303 y_seg = tf.split( y, 2, axis=tdim+1 )[1] 1304 y_intensity_p = tf.split( y_pred, 2, axis=tdim+1 )[0] 1305 y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )[1] 1306 squared_difference = tf.square( y_intensity - y_intensity_p ) 1307 msqTerm = tf.reduce_mean(squared_difference, axis=myax) 1308 temp1 = feature_extractor(y_intensity) 1309 temp2 = feature_extractor(y_intensity_p) 1310 feature_difference = tf.square(temp1-temp2) 1311 featureTerm = tf.reduce_mean(feature_difference, axis=myax) 1312 msqw = 10.0 1313 featw = feature * msqw * msqTerm / featureTerm 1314 mytv = tf.cast( 0.0, 'float32') 1315 if tdim == 3: 1316 for k in range( y_pred.shape[0] ): # BUG not sure why myr fails .... might be old TF version 1317 sqzd = y_pred[k,:,:,:,0] 1318 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) 1319 if tdim == 2: 1320 mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0] ) ) 1321 tvw = tv * msqw * msqTerm / mytv 1322 mydice = binary_dice_loss( y_seg, y_seg_p ) 1323 mydice = tf.reduce_mean( mydice ) 1324 dicew = dice * msqw * msqTerm / mydice 1325 dicewt = np.abs( dicew.numpy().mean() ) 1326 if verbose : 1327 print( "MSQ: " + str( msqw * msqTerm ) ) 1328 print( "Feat: " + str( featw * featureTerm ) ) 1329 print( "Tv: " + str( mytv * tvw ) ) 1330 print( "Dice: " + str( mydice * dicewt ) ) 1331 wts = [msqw,featw.numpy().mean(),tvw.numpy().mean(), dicewt ] 1332 return wts
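The segmentation-aware counterpart adds a Dice term; a sketch assuming a dual-channel model built with sigmoid_second_channel and pro_seg, and two-channel patches (hypothetical file list):

import siq

filenames = ['/path/to/t1_a.nii.gz']                        # hypothetical
X2, y2, _ = siq.seg_patch_training_data_from_filenames(
    filenames, (64, 64, 64), (32, 32, 32), nPatches=2, to_tensorflow=True)
feat = siq.get_grader_feature_network(layer=6)
seg_mdl = siq.default_dbpn([2, 2, 2], option='small',
                           sigmoid_second_channel=True, pro_seg=1)
wts_seg = siq.auto_weight_loss_seg(seg_mdl, feat, X2, y2, feature=2.0, tv=0.1, dice=0.5)
# wts_seg == [msq_weight, feature_weight, tv_weight, dice_weight]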
def numpy_generator( filenames ):
    """
    A placeholder or stub for a data generator.

    This generator yields a tuple of `None` values once and then stops. It is
    likely intended as a template or for debugging purposes where a generator
    object is required but no actual data needs to be processed.

    Parameters
    ----------
    filenames : any
        An argument that is not used by the function.

    Yields
    ------
    tuple
        A single tuple `(None, None, None)`.
    """
    patchesResam = patchesOrig = patchesUp = None
    yield (patchesResam, patchesOrig, patchesUp)
1355def image_generator( 1356 filenames, 1357 nPatches, 1358 target_patch_size, 1359 target_patch_size_low, 1360 patch_scaler=True, 1361 istest=False, 1362 verbose = False ): 1363 """ 1364 Creates an infinite generator of paired image patches for model training. 1365 1366 This function continuously generates batches of low-resolution (input) and 1367 high-resolution (ground truth) image patches. It is designed to be fed 1368 directly into a Keras `model.fit()` call. 1369 1370 Parameters 1371 ---------- 1372 filenames : list of str 1373 List of file paths to the high-resolution source images. 1374 nPatches : int 1375 The number of patch pairs to generate and yield in each batch. 1376 target_patch_size : tuple or list of int 1377 The dimensions of the high-resolution (ground truth) patches. 1378 target_patch_size_low : tuple or list of int 1379 The dimensions of the low-resolution (input) patches. 1380 patch_scaler : bool, optional 1381 If True, scales patch intensities to [0, 1]. Default is True. 1382 istest : bool, optional 1383 If True, the generator will also yield a third item: a baseline 1384 linearly upsampled version of the low-resolution patch for comparison. 1385 Default is False. 1386 verbose : bool, optional 1387 If True, passes verbosity to the underlying patch generation function. 1388 Default is False. 1389 1390 Yields 1391 ------- 1392 tuple 1393 A tuple of TensorFlow tensors ready for training or evaluation. 1394 - If `istest` is False: `(low_res_batch, high_res_batch)` 1395 - If `istest` is True: `(low_res_batch, high_res_batch, baseline_upsampled_batch)` 1396 1397 See Also 1398 -------- 1399 image_patch_training_data_from_filenames : The function that performs the 1400 underlying patch extraction. 1401 """ 1402 while True: 1403 patchesResam, patchesOrig, patchesUp = image_patch_training_data_from_filenames( 1404 filenames, 1405 target_patch_size = target_patch_size, 1406 target_patch_size_low = target_patch_size_low, 1407 nPatches = nPatches, 1408 istest = istest, 1409 patch_scaler=patch_scaler, 1410 to_tensorflow = True, 1411 verbose = verbose ) 1412 if istest: 1413 yield (patchesResam, patchesOrig,patchesUp) 1414 yield (patchesResam, patchesOrig)
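A generator sketch for Keras-style training loops (hypothetical file list):

import siq

filenames = ['/path/to/t1_a.nii.gz', '/path/to/t1_b.nii.gz']  # hypothetical
gen = siq.image_generator(
    filenames,
    nPatches=2,
    target_patch_size=(64, 64, 64),
    target_patch_size_low=(32, 32, 32))
lo, hi = next(gen)   # one batch: (low-res input, high-res ground truth)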
1417def seg_generator( 1418 filenames, 1419 nPatches, 1420 target_patch_size, 1421 target_patch_size_low, 1422 patch_scaler=True, 1423 istest=False, 1424 verbose = False ): 1425 """ 1426 Creates an infinite generator of paired image and segmentation patches. 1427 1428 This function continuously generates batches of multi-channel patches, where 1429 one channel is the intensity image and the other is a segmentation mask. 1430 It is designed for training multi-task super-resolution models. 1431 1432 Parameters 1433 ---------- 1434 filenames : list of str 1435 List of file paths to the high-resolution source images. 1436 nPatches : int 1437 The number of patch pairs to generate and yield in each batch. 1438 target_patch_size : tuple or list of int 1439 The dimensions of the high-resolution patches. 1440 target_patch_size_low : tuple or list of int 1441 The dimensions of the low-resolution patches. 1442 patch_scaler : bool, optional 1443 If True, scales the intensity channel of patches to [0, 1]. Default is True. 1444 istest : bool, optional 1445 If True, yields an additional baseline upsampled patch for comparison. 1446 Default is False. 1447 verbose : bool, optional 1448 If True, passes verbosity to the underlying patch generation function. 1449 Default is False. 1450 1451 Yields 1452 ------- 1453 tuple 1454 A tuple of multi-channel TensorFlow tensors. Each tensor has two channels: 1455 Channel 0 contains the intensity image, and Channel 1 contains the 1456 segmentation mask. 1457 1458 See Also 1459 -------- 1460 seg_patch_training_data_from_filenames : The function that performs the 1461 underlying patch extraction. 1462 image_generator : A similar generator for intensity-only data. 1463 """ 1464 while True: 1465 patchesResam, patchesOrig, patchesUp = seg_patch_training_data_from_filenames( 1466 filenames, 1467 target_patch_size = target_patch_size, 1468 target_patch_size_low = target_patch_size_low, 1469 nPatches = nPatches, 1470 istest = istest, 1471 patch_scaler=patch_scaler, 1472 to_tensorflow = True, 1473 verbose = verbose ) 1474 if istest: 1475 yield (patchesResam, patchesOrig,patchesUp) 1476 yield (patchesResam, patchesOrig)
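The segmentation variant yields two-channel tensors instead (same hypothetical file list):

import siq

filenames = ['/path/to/t1_a.nii.gz', '/path/to/t1_b.nii.gz']  # hypothetical
seg_gen = siq.seg_generator(
    filenames,
    nPatches=2,
    target_patch_size=(64, 64, 64),
    target_patch_size_low=(32, 32, 32))
lo2, hi2 = next(seg_gen)   # channel 0 = intensity, channel 1 = mask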
1479def train( 1480 mdl, 1481 filenames_train, 1482 filenames_test, 1483 target_patch_size, 1484 target_patch_size_low, 1485 output_prefix, 1486 n_test = 8, 1487 learning_rate=5e-5, 1488 feature_layer = 6, 1489 feature = 2, 1490 tv = 0.1, 1491 max_iterations = 1000, 1492 batch_size = 1, 1493 save_all_best = False, 1494 feature_type = 'grader', 1495 check_eval_data_iteration = 20, 1496 verbose = False ): 1497 """ 1498 Orchestrates the training process for a super-resolution model. 1499 1500 This function handles the entire training loop, including setting up data 1501 generators, defining a composite loss function, automatically balancing loss 1502 weights, iteratively training the model, periodically evaluating performance, 1503 and saving the best-performing model weights. 1504 1505 Parameters 1506 ---------- 1507 mdl : tf.keras.Model 1508 The Keras model to be trained. 1509 filenames_train : list of str 1510 List of file paths for the training dataset. 1511 filenames_test : list of str 1512 List of file paths for the validation/testing dataset. 1513 target_patch_size : tuple or list 1514 The dimensions of the high-resolution target patches. 1515 target_patch_size_low : tuple or list 1516 The dimensions of the low-resolution input patches. 1517 output_prefix : str 1518 A prefix for all output files (e.g., model weights, training logs). 1519 n_test : int, optional 1520 The number of validation patches to use for evaluation. Default is 8. 1521 learning_rate : float, optional 1522 The learning rate for the Adam optimizer. Default is 5e-5. 1523 feature_layer : int, optional 1524 The layer index from the feature extractor to use for perceptual loss. 1525 Default is 6. 1526 feature : float, optional 1527 The relative weight of the perceptual (feature) loss term. Default is 2.0. 1528 tv : float, optional 1529 The relative weight of the Total Variation (TV) regularization term. 1530 Default is 0.1. 1531 max_iterations : int, optional 1532 The total number of training iterations to run. Default is 1000. 1533 batch_size : int, optional 1534 The batch size for training. Note: this implementation is optimized for 1535 batch_size=1 and may need adjustment for larger batches. Default is 1. 1536 save_all_best : bool, optional 1537 If True, saves a new model file every time validation loss improves. 1538 If False, overwrites the single best model file. Default is False. 1539 feature_type : str, optional 1540 The type of feature extractor for perceptual loss. Options: 'grader', 1541 'vgg', 'vggrandom'. Default is 'grader'. 1542 check_eval_data_iteration : int, optional 1543 The frequency (in iterations) at which to run validation and save logs. 1544 Default is 20. 1545 verbose : bool, optional 1546 If True, prints detailed progress information. Default is False. 1547 1548 Returns 1549 ------- 1550 pd.DataFrame 1551 A DataFrame containing the training history, with columns for training 1552 loss, validation loss, PSNR, and baseline PSNR over iterations. 
1553 """ 1554 colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin'] 1555 training_path = np.zeros( [ max_iterations, len(colnames) ] ) 1556 training_weights = np.zeros( [1,3] ) 1557 if verbose: 1558 print("begin get feature extractor " + feature_type) 1559 if feature_type == 'grader': 1560 feature_extractor = get_grader_feature_network( feature_layer ) 1561 elif feature_type == 'vggrandom': 1562 with eager_mode(): 1563 feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False ) 1564 elif feature_type == 'vgg': 1565 with eager_mode(): 1566 feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer ) 1567 else: 1568 raise Exception("feature type does not exist") 1569 if verbose: 1570 print("begin train generator") 1571 mydatgen = image_generator( 1572 filenames_train, 1573 nPatches=1, 1574 target_patch_size=target_patch_size, 1575 target_patch_size_low=target_patch_size_low, 1576 istest=False , verbose=False) 1577 if verbose: 1578 print("begin test generator") 1579 mydatgenTest = image_generator( filenames_test, nPatches=1, 1580 target_patch_size=target_patch_size, 1581 target_patch_size_low=target_patch_size_low, 1582 istest=True, verbose=True) 1583 patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest ) 1584 if len( patchesOrigTeTf.shape ) == 5: 1585 tdim = 3 1586 myax = [1,2,3,4] 1587 if len( patchesOrigTeTf.shape ) == 4: 1588 tdim = 2 1589 myax = [1,2,3] 1590 if verbose: 1591 print("begin train generator #2 at dim: " + str( tdim)) 1592 mydatgenTest = image_generator( filenames_test, nPatches=1, 1593 target_patch_size=target_patch_size, 1594 target_patch_size_low=target_patch_size_low, 1595 istest=True, verbose=True) 1596 patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest ) 1597 for k in range( n_test - 1 ): 1598 mydatgenTest = image_generator( filenames_test, nPatches=1, 1599 target_patch_size=target_patch_size, 1600 target_patch_size_low=target_patch_size_low, 1601 istest=True, verbose=True) 1602 temp0, temp1, temp2 = next( mydatgenTest ) 1603 patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0) 1604 patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0) 1605 patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0) 1606 if verbose: 1607 print("begin auto_weight_loss") 1608 wts_csv = output_prefix + "_training_weights.csv" 1609 if exists( wts_csv ): 1610 wtsdf = pd.read_csv( wts_csv ) 1611 wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0]] 1612 if verbose: 1613 print( "preset weights:" ) 1614 else: 1615 with eager_mode(): 1616 wts = auto_weight_loss( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf, 1617 feature=feature, tv=tv ) 1618 for k in range(len(wts)): 1619 training_weights[0,k]=wts[k] 1620 pd.DataFrame(training_weights, columns = ["msq","feat","tv"] ).to_csv( wts_csv ) 1621 if verbose: 1622 print( "automatic weights:" ) 1623 if verbose: 1624 print( wts ) 1625 def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], mybs = batch_size ): 1626 """Composite loss: MSE + Perceptual Loss + Total Variation.""" 1627 squared_difference = tf.square(y_true - y_pred) 1628 if len( y_true.shape ) == 5: 1629 tdim = 3 1630 myax = [1,2,3,4] 1631 if len( y_true.shape ) == 4: 1632 tdim = 2 1633 myax = [1,2,3] 1634 msqTerm = tf.reduce_mean(squared_difference, axis=myax) 1635 temp1 = feature_extractor(y_true) 1636 temp2 = feature_extractor(y_pred) 1637 feature_difference = tf.square(temp1-temp2) 1638 featureTerm = 
tf.reduce_mean(feature_difference, axis=myax) 1639 loss = msqTerm * msqwt + featureTerm * fw 1640 mytv = tf.cast( 0.0, 'float32') 1641 # mybs = int( y_pred.shape[0] ) --- should work but ... ? 1642 if tdim == 3: 1643 for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version 1644 sqzd = y_pred[k,:,:,:,:] 1645 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt 1646 if tdim == 2: 1647 mytv = tf.reduce_mean( tf.image.total_variation( y_pred ) ) * tvwt 1648 return( loss + mytv ) 1649 if verbose: 1650 print("begin model compilation") 1651 opt = tf.keras.optimizers.Adam( learning_rate=learning_rate ) 1652 mdl.compile(optimizer=opt, loss=my_loss_6) 1653 # set up some parameters for tracking performance 1654 bestValLoss=1e12 1655 bestSSIM=0.0 1656 bestQC0 = -1000 1657 bestQC1 = -1000 1658 if verbose: 1659 print( "begin training", flush=True ) 1660 for myrs in range( max_iterations ): 1661 tracker = mdl.fit( mydatgen, epochs=2, steps_per_epoch=4, verbose=1, 1662 validation_data=(patchesResamTeTf,patchesOrigTeTf) ) 1663 training_path[myrs,0]=tracker.history['loss'][0] 1664 training_path[myrs,1]=tracker.history['val_loss'][0] 1665 training_path[myrs,2]=0 1666 print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True ) 1667 if myrs % check_eval_data_iteration == 0: 1668 with tf.device("/cpu:0"): 1669 myofn = output_prefix + "_best_mdl.keras" 1670 if save_all_best: 1671 myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras" 1672 tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB ) 1673 if ( tester < bestValLoss ): 1674 print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True ) 1675 bestValLoss = tester 1676 tf.keras.models.save_model( mdl, myofn ) 1677 training_path[myrs,2]=1 1678 pp = mdl.predict( patchesResamTeTfB, batch_size = 1 ) 1679 myssimSR = tf.image.psnr( pp * 220, patchesOrigTeTfB* 220, max_val=255 ) 1680 myssimSR = tf.reduce_mean( myssimSR ).numpy() 1681 myssimBI = tf.image.psnr( patchesUpTeTfB * 220, patchesOrigTeTfB* 220, max_val=255 ) 1682 myssimBI = tf.reduce_mean( myssimBI ).numpy() 1683 print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ), flush=True ) 1684 training_path[myrs,3]=myssimSR # psnr 1685 training_path[myrs,4]=myssimBI # psnrlin 1686 pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" ) 1687 training_path = pd.DataFrame(training_path, columns = colnames ) 1688 return training_path
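An end-to-end training sketch (hypothetical paths and small iteration counts; the 'grader' feature type additionally assumes ~/.antspyt1w/resnet_grader.h5 is present):

import siq

filenames_train = ['/path/to/train_a.nii.gz', '/path/to/train_b.nii.gz']  # hypothetical
filenames_test = ['/path/to/test_a.nii.gz']                               # hypothetical
mdl = siq.default_dbpn([2, 2, 2], option='small')
history = siq.train(
    mdl,
    filenames_train,
    filenames_test,
    target_patch_size=(64, 64, 64),
    target_patch_size_low=(32, 32, 32),
    output_prefix='/tmp/siq_2x',           # hypothetical output prefix
    n_test=4,
    max_iterations=100,
    feature_type='grader',
    verbose=True)
# history is a DataFrame; the best weights are saved to /tmp/siq_2x_best_mdl.keras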
Orchestrates the training process for a super-resolution model.
This function handles the entire training loop, including setting up data generators, defining a composite loss function, automatically balancing loss weights, iteratively training the model, periodically evaluating performance, and saving the best-performing model weights.
Parameters
mdl : tf.keras.Model
    The Keras model to be trained.
filenames_train : list of str
    List of file paths for the training dataset.
filenames_test : list of str
    List of file paths for the validation/testing dataset.
target_patch_size : tuple or list
    The dimensions of the high-resolution target patches.
target_patch_size_low : tuple or list
    The dimensions of the low-resolution input patches.
output_prefix : str
    A prefix for all output files (e.g., model weights, training logs).
n_test : int, optional
    The number of validation patches to use for evaluation. Default is 8.
learning_rate : float, optional
    The learning rate for the Adam optimizer. Default is 5e-5.
feature_layer : int, optional
    The layer index from the feature extractor to use for perceptual loss. Default is 6.
feature : float, optional
    The relative weight of the perceptual (feature) loss term. Default is 2.0.
tv : float, optional
    The relative weight of the Total Variation (TV) regularization term. Default is 0.1.
max_iterations : int, optional
    The total number of training iterations to run. Default is 1000.
batch_size : int, optional
    The batch size for training. Note: this implementation is optimized for
    batch_size=1 and may need adjustment for larger batches. Default is 1.
save_all_best : bool, optional
    If True, saves a new model file every time validation loss improves.
    If False, overwrites the single best model file. Default is False.
feature_type : str, optional
    The type of feature extractor for perceptual loss. Options: 'grader', 'vgg',
    'vggrandom'. Default is 'grader'.
check_eval_data_iteration : int, optional
    The frequency (in iterations) at which to run validation and save logs. Default is 20.
verbose : bool, optional
    If True, prints detailed progress information. Default is False.
Returns
pd.DataFrame
    A DataFrame containing the training history, with columns for training loss,
    validation loss, PSNR, and baseline PSNR over iterations.
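A minimal usage sketch, assuming illustrative file lists, patch sizes, and a 2x 3D configuration of the dbpn constructor defined earlier in this module (the kernel/stride settings follow the 2x recommendations in Haris et al. and are assumptions here, not package defaults):
>>> import glob
>>> import siq
>>> fns = sorted(glob.glob('/data/t1/*.nii.gz'))        # hypothetical image list
>>> fns_train, fns_test = fns[:-8], fns[-8:]
>>> mdl = siq.dbpn((None, None, None, 1),               # 3D, single channel
...                convolution_kernel_size=(6, 6, 6),   # assumed 2x settings
...                strides=(2, 2, 2),
...                last_convolution=(3, 3, 3))
>>> history = siq.train(mdl, fns_train, fns_test,
...                     target_patch_size=[32, 32, 32],      # high-resolution patch
...                     target_patch_size_low=[16, 16, 16],  # low-resolution patch (2x smaller)
...                     output_prefix='/tmp/siq_2x',
...                     max_iterations=2, verbose=True)      # tiny run for illustration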
1691def binary_dice_loss(y_true, y_pred): 1692 """ 1693 Computes the Dice loss for binary segmentation tasks. 1694 1695 The Dice coefficient is a common metric for comparing the overlap of two samples. 1696 This loss function computes `1 - DiceCoefficient`, making it suitable for 1697 minimization during training. A smoothing factor is added to avoid division 1698 by zero when both the prediction and the ground truth are empty. 1699 1700 Parameters 1701 ---------- 1702 y_true : tf.Tensor 1703 The ground truth binary segmentation mask. Values should be 0 or 1. 1704 y_pred : tf.Tensor 1705 The predicted binary segmentation mask, typically with values in [0, 1] 1706 from a sigmoid activation. 1707 1708 Returns 1709 ------- 1710 tf.Tensor 1711 A scalar tensor representing the Dice loss. The value ranges from -1 (perfect 1712 match) to 0 (no overlap), though it's typically used as `1 - dice_coeff` 1713 or just `-dice_coeff` (as here). 1714 """ 1715 smoothing_factor = 1e-4 1716 K = tf.keras.backend 1717 y_true_f = K.flatten(y_true) 1718 y_pred_f = K.flatten(y_pred) 1719 intersection = K.sum(y_true_f * y_pred_f) 1720 # This is -1 * Dice Similarity Coefficient 1721 return -1 * (2 * intersection + smoothing_factor)/(K.sum(y_true_f) + 1722 K.sum(y_pred_f) + smoothing_factor)
Computes the Dice loss for binary segmentation tasks.
The Dice coefficient is a common metric for comparing the overlap of two samples.
This loss function returns the negative Dice coefficient, making it suitable for
minimization during training. A smoothing factor is added to avoid division by zero
when both the prediction and the ground truth are empty.
Parameters
y_true : tf.Tensor
    The ground truth binary segmentation mask. Values should be 0 or 1.
y_pred : tf.Tensor
    The predicted binary segmentation mask, typically with values in [0, 1]
    from a sigmoid activation.
Returns
tf.Tensor
    A scalar tensor representing the Dice loss. The value ranges from -1 (perfect
    match) to 0 (no overlap); minimizing this negative Dice coefficient is
    equivalent to minimizing 1 - dice_coeff.
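For intuition, a small hand-checkable sketch with toy tensors (assumes binary_dice_loss is exposed at the package level):
>>> import tensorflow as tf
>>> import siq
>>> a = tf.constant([[1.0, 1.0, 0.0, 0.0]])
>>> b = tf.constant([[1.0, 0.0, 0.0, 0.0]])
>>> round(float(siq.binary_dice_loss(a, a)), 3)   # perfect overlap
-1.0
>>> round(float(siq.binary_dice_loss(a, b)), 3)   # partial overlap
-0.667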
1724def train_seg( 1725 mdl, 1726 filenames_train, 1727 filenames_test, 1728 target_patch_size, 1729 target_patch_size_low, 1730 output_prefix, 1731 n_test = 8, 1732 learning_rate=5e-5, 1733 feature_layer = 6, 1734 feature = 2, 1735 tv = 0.1, 1736 dice = 0.5, 1737 max_iterations = 1000, 1738 batch_size = 1, 1739 save_all_best = False, 1740 feature_type = 'grader', 1741 check_eval_data_iteration = 20, 1742 verbose = False ): 1743 """ 1744 Orchestrates training for a multi-task image and segmentation SR model. 1745 1746 This function extends the `train` function to handle models that predict 1747 both a super-resolved image and a super-resolved segmentation mask. It uses 1748 a four-component composite loss: MSE (for image), a perceptual loss (for 1749 image), Total Variation (for image), and Dice loss (for segmentation). 1750 1751 Parameters 1752 ---------- 1753 mdl : tf.keras.Model 1754 The 2-channel Keras model to be trained. 1755 filenames_train : list of str 1756 List of file paths for the training dataset. 1757 filenames_test : list of str 1758 List of file paths for the validation/testing dataset. 1759 target_patch_size : tuple or list 1760 The dimensions of the high-resolution target patches. 1761 target_patch_size_low : tuple or list 1762 The dimensions of the low-resolution input patches. 1763 output_prefix : str 1764 A prefix for all output files. 1765 n_test : int, optional 1766 Number of validation patches for evaluation. Default is 8. 1767 learning_rate : float, optional 1768 Learning rate for the Adam optimizer. Default is 5e-5. 1769 feature_layer : int, optional 1770 Layer from the feature extractor for perceptual loss. Default is 6. 1771 feature : float, optional 1772 Relative weight of the perceptual loss term. Default is 2.0. 1773 tv : float, optional 1774 Relative weight of the Total Variation regularization term. Default is 0.1. 1775 dice : float, optional 1776 Relative weight of the Dice loss term for the segmentation mask. 1777 Default is 0.5. 1778 max_iterations : int, optional 1779 Total number of training iterations. Default is 1000. 1780 batch_size : int, optional 1781 The batch size for training. Default is 1. 1782 save_all_best : bool, optional 1783 If True, saves all models that improve validation loss. Default is False. 1784 feature_type : str, optional 1785 Type of feature extractor for perceptual loss. Default is 'grader'. 1786 check_eval_data_iteration : int, optional 1787 Frequency (in iterations) for running validation. Default is 20. 1788 verbose : bool, optional 1789 If True, prints detailed progress information. Default is False. 1790 1791 Returns 1792 ------- 1793 pd.DataFrame 1794 A DataFrame containing the training history, including columns for losses 1795 and evaluation metrics like PSNR and Dice score. 1796 1797 See Also 1798 -------- 1799 train : The training function for single-task (intensity-only) models. 
1800 """ 1801 colnames = ['train_loss','test_loss','best','eval_psnr','eval_psnr_lin','eval_msq','eval_dice'] 1802 training_path = np.zeros( [ max_iterations, len(colnames) ] ) 1803 training_weights = np.zeros( [1,4] ) 1804 if verbose: 1805 print("begin get feature extractor") 1806 if feature_type == 'grader': 1807 feature_extractor = get_grader_feature_network( feature_layer ) 1808 elif feature_type == 'vggrandom': 1809 feature_extractor = pseudo_3d_vgg_features( target_patch_size, feature_layer, pretrained=False ) 1810 else: 1811 feature_extractor = pseudo_3d_vgg_features_unbiased( target_patch_size, feature_layer ) 1812 if verbose: 1813 print("begin train generator") 1814 mydatgen = seg_generator( 1815 filenames_train, 1816 nPatches=1, 1817 target_patch_size=target_patch_size, 1818 target_patch_size_low=target_patch_size_low, 1819 istest=False , verbose=False) 1820 if verbose: 1821 print("begin test generator") 1822 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1823 target_patch_size=target_patch_size, 1824 target_patch_size_low=target_patch_size_low, 1825 istest=True, verbose=True) 1826 patchesResamTeTf, patchesOrigTeTf, patchesUpTeTf = next( mydatgenTest ) 1827 if len( patchesOrigTeTf.shape ) == 5: 1828 tdim = 3 1829 myax = [1,2,3,4] 1830 if len( patchesOrigTeTf.shape ) == 4: 1831 tdim = 2 1832 myax = [1,2,3] 1833 if verbose: 1834 print("begin train generator #2 at dim: " + str( tdim)) 1835 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1836 target_patch_size=target_patch_size, 1837 target_patch_size_low=target_patch_size_low, 1838 istest=True, verbose=True) 1839 patchesResamTeTfB, patchesOrigTeTfB, patchesUpTeTfB = next( mydatgenTest ) 1840 for k in range( n_test - 1 ): 1841 mydatgenTest = seg_generator( filenames_test, nPatches=1, 1842 target_patch_size=target_patch_size, 1843 target_patch_size_low=target_patch_size_low, 1844 istest=True, verbose=True) 1845 temp0, temp1, temp2 = next( mydatgenTest ) 1846 patchesResamTeTfB = tf.concat( [patchesResamTeTfB,temp0],axis=0) 1847 patchesOrigTeTfB = tf.concat( [patchesOrigTeTfB,temp1],axis=0) 1848 patchesUpTeTfB = tf.concat( [patchesUpTeTfB,temp2],axis=0) 1849 if verbose: 1850 print("begin auto_weight_loss_seg") 1851 wts_csv = output_prefix + "_training_weights.csv" 1852 if exists( wts_csv ): 1853 wtsdf = pd.read_csv( wts_csv ) 1854 wts = [wtsdf['msq'][0], wtsdf['feat'][0], wtsdf['tv'][0], wtsdf['dice'][0]] 1855 if verbose: 1856 print( "preset weights:" ) 1857 else: 1858 wts = auto_weight_loss_seg( mdl, feature_extractor, patchesResamTeTf, patchesOrigTeTf, 1859 feature=feature, tv=tv, dice=dice ) 1860 for k in range(len(wts)): 1861 training_weights[0,k]=wts[k] 1862 pd.DataFrame(training_weights, columns = ["msq","feat","tv","dice"] ).to_csv( wts_csv ) 1863 if verbose: 1864 print( "automatic weights:" ) 1865 if verbose: 1866 print( wts ) 1867 def my_loss_6( y_true, y_pred, msqwt = wts[0], fw = wts[1], tvwt = wts[2], dicewt=wts[3], mybs = batch_size ): 1868 """Composite loss: MSE + Perceptual + TV + Dice.""" 1869 if len( y_true.shape ) == 5: 1870 tdim = 3 1871 myax = [1,2,3,4] 1872 if len( y_true.shape ) == 4: 1873 tdim = 2 1874 myax = [1,2,3] 1875 y_intensity = tf.split( y_true, 2, axis=tdim+1 )[0] 1876 y_seg = tf.split( y_true, 2, axis=tdim+1 )[1] 1877 y_intensity_p = tf.split( y_pred, 2, axis=tdim+1 )[0] 1878 y_seg_p = tf.split( y_pred, 2, axis=tdim+1 )[1] 1879 squared_difference = tf.square(y_intensity - y_intensity_p) 1880 msqTerm = tf.reduce_mean(squared_difference, axis=myax) 1881 temp1 = 
feature_extractor(y_intensity) 1882 temp2 = feature_extractor(y_intensity_p) 1883 feature_difference = tf.square(temp1-temp2) 1884 featureTerm = tf.reduce_mean(feature_difference, axis=myax) 1885 loss = msqTerm * msqwt + featureTerm * fw 1886 mytv = tf.cast( 0.0, 'float32') 1887 if tdim == 3: 1888 for k in range( mybs ): # BUG not sure why myr fails .... might be old TF version 1889 sqzd = y_pred[k,:,:,:,0] 1890 mytv = mytv + tf.reduce_mean( tf.image.total_variation( sqzd ) ) * tvwt 1891 if tdim == 2: 1892 mytv = tf.reduce_mean( tf.image.total_variation( y_pred[:,:,:,0] ) ) * tvwt 1893 dicer = tf.reduce_mean( dicewt * binary_dice_loss( y_seg, y_seg_p ) ) 1894 return( loss + mytv + dicer ) 1895 if verbose: 1896 print("begin model compilation") 1897 opt = tf.keras.optimizers.Adam( learning_rate=learning_rate ) 1898 mdl.compile(optimizer=opt, loss=my_loss_6) 1899 # set up some parameters for tracking performance 1900 bestValLoss=1e12 1901 bestSSIM=0.0 1902 bestQC0 = -1000 1903 bestQC1 = -1000 1904 if verbose: 1905 print( "begin training", flush=True ) 1906 for myrs in range( max_iterations ): 1907 tracker = mdl.fit( mydatgen, epochs=2, steps_per_epoch=4, verbose=1, 1908 validation_data=(patchesResamTeTf,patchesOrigTeTf) ) 1909 training_path[myrs,0]=tracker.history['loss'][0] 1910 training_path[myrs,1]=tracker.history['val_loss'][0] 1911 training_path[myrs,2]=0 1912 print( "ntrain: " + str(myrs) + " loss " + str( tracker.history['loss'][0] ) + ' val-loss ' + str(tracker.history['val_loss'][0]), flush=True ) 1913 if myrs % check_eval_data_iteration == 0: 1914 with tf.device("/cpu:0"): 1915 myofn = output_prefix + "_best_mdl.keras" 1916 if save_all_best: 1917 myofn = output_prefix + "_" + str(myrs)+ "_mdl.keras" 1918 tester = mdl.evaluate( patchesResamTeTfB, patchesOrigTeTfB ) 1919 if ( tester < bestValLoss ): 1920 print("MyIT " + str( myrs ) + " IS BEST!! " + str( tester ) + myofn, flush=True ) 1921 bestValLoss = tester 1922 tf.keras.models.save_model( mdl, myofn ) 1923 training_path[myrs,2]=1 1924 pp = mdl.predict( patchesResamTeTfB, batch_size = 1 ) 1925 pp = tf.split( pp, 2, axis=tdim+1 ) 1926 y_orig = tf.split( patchesOrigTeTfB, 2, axis=tdim+1 ) 1927 y_up = tf.split( patchesUpTeTfB, 2, axis=tdim+1 ) 1928 myssimSR = tf.image.psnr( pp[0] * 220, y_orig[0]* 220, max_val=255 ) 1929 myssimSR = tf.reduce_mean( myssimSR ).numpy() 1930 myssimBI = tf.image.psnr( y_up[0] * 220, y_orig[0]* 220, max_val=255 ) 1931 myssimBI = tf.reduce_mean( myssimBI ).numpy() 1932 squared_difference = tf.square(y_orig[0] - pp[0]) 1933 msqTerm = tf.reduce_mean(squared_difference).numpy() 1934 dicer = binary_dice_loss( y_orig[1], pp[1] ) 1935 dicer = tf.reduce_mean( dicer ).numpy() 1936 print( myofn + " : " + "PSNR Lin: " + str( myssimBI ) + " SR: " + str( myssimSR ) + " MSQ: " + str(msqTerm) + " DICE: " + str(dicer), flush=True ) 1937 training_path[myrs,3]=myssimSR # psnr 1938 training_path[myrs,4]=myssimBI # psnrlin 1939 training_path[myrs,5]=msqTerm # msq 1940 training_path[myrs,6]=dicer # dice 1941 pd.DataFrame(training_path, columns = colnames ).to_csv( output_prefix + "_training.csv" ) 1942 training_path = pd.DataFrame(training_path, columns = colnames ) 1943 return training_path
Orchestrates training for a multi-task image and segmentation SR model.
This function extends the train function to handle models that predict
both a super-resolved image and a super-resolved segmentation mask. It uses
a four-component composite loss: MSE (for image), a perceptual loss (for
image), Total Variation (for image), and Dice loss (for segmentation).
Parameters
mdl : tf.keras.Model
    The 2-channel Keras model to be trained.
filenames_train : list of str
    List of file paths for the training dataset.
filenames_test : list of str
    List of file paths for the validation/testing dataset.
target_patch_size : tuple or list
    The dimensions of the high-resolution target patches.
target_patch_size_low : tuple or list
    The dimensions of the low-resolution input patches.
output_prefix : str
    A prefix for all output files.
n_test : int, optional
    Number of validation patches for evaluation. Default is 8.
learning_rate : float, optional
    Learning rate for the Adam optimizer. Default is 5e-5.
feature_layer : int, optional
    Layer from the feature extractor for perceptual loss. Default is 6.
feature : float, optional
    Relative weight of the perceptual loss term. Default is 2.0.
tv : float, optional
    Relative weight of the Total Variation regularization term. Default is 0.1.
dice : float, optional
    Relative weight of the Dice loss term for the segmentation mask. Default is 0.5.
max_iterations : int, optional
    Total number of training iterations. Default is 1000.
batch_size : int, optional
    The batch size for training. Default is 1.
save_all_best : bool, optional
    If True, saves all models that improve validation loss. Default is False.
feature_type : str, optional
    Type of feature extractor for perceptual loss. Default is 'grader'.
check_eval_data_iteration : int, optional
    Frequency (in iterations) for running validation. Default is 20.
verbose : bool, optional
    If True, prints detailed progress information. Default is False.
Returns
pd.DataFrame
    A DataFrame containing the training history, including columns for losses
    and evaluation metrics like PSNR and Dice score.
See Also
train : The training function for single-task (intensity-only) models.
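A sketch mirroring the train example above, reusing its illustrative file lists and assuming a two-channel (intensity + segmentation) dbpn configuration; the exact architecture settings are assumptions:
>>> import siq
>>> mdl2 = siq.dbpn((None, None, None, 2),             # 2 input channels: intensity + segmentation
...                 number_of_outputs=2,               # 2 output channels as well
...                 convolution_kernel_size=(6, 6, 6), # assumed 2x settings
...                 strides=(2, 2, 2),
...                 last_convolution=(3, 3, 3))
>>> history = siq.train_seg(mdl2, fns_train, fns_test,
...                         target_patch_size=[32, 32, 32],
...                         target_patch_size_low=[16, 16, 16],
...                         output_prefix='/tmp/siq_seg_2x',
...                         dice=0.5, max_iterations=2, verbose=True)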
1946def read_srmodel(srfilename, custom_objects=None): 1947 """ 1948 Load a super-resolution model (h5, .keras, or SavedModel format), 1949 and determine its upsampling factor. 1950 1951 Parameters 1952 ---------- 1953 srfilename : str 1954 Path to the model file (.h5, .keras, or a SavedModel folder). 1955 custom_objects : dict, optional 1956 Dictionary of custom objects used in the model (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)}) 1957 1958 Returns 1959 ------- 1960 model : tf.keras.Model 1961 The loaded model. 1962 upsampling_factor : list of int 1963 List describing the upsampling factor: 1964 - For 3D input: [x_up, y_up, z_up, channels] 1965 - For 2D input: [x_up, y_up, channels] 1966 1967 Example 1968 ------- 1969 >>> mdl, up = read_srmodel("mymodel.keras") 1970 >>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)}) 1971 """ 1972 1973 # Expand path and detect format 1974 srfilename = os.path.expanduser(srfilename) 1975 ext = os.path.splitext(srfilename)[1].lower() 1976 1977 if os.path.isdir(srfilename): 1978 # SavedModel directory 1979 model = tf.keras.models.load_model(srfilename, custom_objects=custom_objects, compile=False) 1980 elif ext in ['.h5', '.keras']: 1981 model = tf.keras.models.load_model(srfilename, custom_objects=custom_objects, compile=False) 1982 else: 1983 raise ValueError(f"Unsupported model format: {ext}") 1984 1985 # Determine channel index 1986 input_shape = model.input_shape 1987 if isinstance(input_shape, list): 1988 input_shape = input_shape[0] 1989 chanindex = 3 if len(input_shape) == 4 else 4 1990 nchan = int(input_shape[chanindex]) 1991 1992 # Run dummy input to compute upsampling factor 1993 try: 1994 if len(input_shape) == 5: # 3D 1995 dummy_input = np.zeros([1, 8, 8, 8, nchan]) 1996 else: # 2D 1997 dummy_input = np.zeros([1, 8, 8, nchan]) 1998 1999 # Handle named inputs if necessary 2000 try: 2001 output = model(dummy_input) 2002 except Exception: 2003 output = model({model.input_names[0]: dummy_input}) 2004 2005 outshp = output.shape 2006 if len(input_shape) == 5: 2007 return model, [int(outshp[1]/8), int(outshp[2]/8), int(outshp[3]/8), nchan] 2008 else: 2009 return model, [int(outshp[1]/8), int(outshp[2]/8), nchan] 2010 2011 except Exception as e: 2012 raise RuntimeError(f"Could not infer upsampling factor. Error: {e}")
Load a super-resolution model (h5, .keras, or SavedModel format), and determine its upsampling factor.
Parameters
srfilename : str
    Path to the model file (.h5, .keras, or a SavedModel folder).
custom_objects : dict, optional
    Dictionary of custom objects used in the model
    (e.g. {'TFOpLambda': tf.keras.layers.Lambda(...)}).
Returns
model : tf.keras.Model
    The loaded model.
upsampling_factor : list of int
    List describing the upsampling factor:
    - For 3D input: [x_up, y_up, z_up, channels]
    - For 2D input: [x_up, y_up, channels]
Example
>>> mdl, up = read_srmodel("mymodel.keras")
>>> mdl, up = read_srmodel("my_weights.h5", custom_objects={"TFOpLambda": tf.keras.layers.Lambda(tf.identity)})
2015def simulate_image( shaper=[32,32,32], n_levels=10, multiply=False ): 2016 """ 2017 generate an image of given shape and number of levels 2018 2019 Arguments 2020 --------- 2021 shaper : [x,y,z] or [x,y] 2022 2023 n_levels : int 2024 2025 multiply : boolean 2026 2027 Returns 2028 ------- 2029 2030 ants.image 2031 2032 """ 2033 img = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) * 0 2034 for k in range(n_levels): 2035 temp = ants.from_numpy( np.random.normal( 0, 1.0, size=shaper ) ) 2036 temp = ants.smooth_image( temp, n_levels ) 2037 temp = ants.threshold_image( temp, "Otsu", 1 ) 2038 if multiply: 2039 temp = temp * k 2040 img = img + temp 2041 return img
Generate a random test image of the given shape with the specified number of levels.
Arguments
shaper : [x,y,z] or [x,y]
    Shape of the output image (3D or 2D).
n_levels : int
    Number of random, Otsu-thresholded levels summed into the image.
multiply : boolean
    If True, weight each level by its index before summing.
Returns
ants.image
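A brief sketch (arbitrary shapes):
>>> import siq
>>> img3d = siq.simulate_image([32, 32, 32], n_levels=10)              # additive random levels
>>> img2d = siq.simulate_image([64, 64], n_levels=5, multiply=True)    # index-weighted levels
>>> img3d.shape
(32, 32, 32)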
2044def optimize_upsampling_shape( spacing, modality='T1', roundit=False, verbose=False ): 2045 """ 2046 Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing 2047 and imaging modality. This output is used to select an appropriate pretrained 2048 super-resolution model filename. 2049 2050 Parameters 2051 ---------- 2052 spacing : sequence of float 2053 Voxel spacing (physical size per voxel in mm) from the input image. 2054 Typically obtained from `ants.get_spacing(image)`. 2055 2056 modality : str, optional 2057 Imaging modality. Affects resolution thresholds: 2058 - 'T1' : anatomical MRI (default minimum spacing: 0.35 mm) 2059 - 'DTI' : diffusion MRI (default minimum spacing: 1.0 mm) 2060 - 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm) 2061 2062 roundit : bool, optional 2063 If True, uses rounded integer ratios for the upsampling shape. 2064 Otherwise, uses floor division with constraints. 2065 2066 verbose : bool, optional 2067 If True, prints detailed internal values and logic. 2068 2069 Returns 2070 ------- 2071 str 2072 Optimal upsampling shape string in the form 'AxBxC', 2073 e.g., '2x2x2', '4x4x2'. 2074 2075 Notes 2076 ----- 2077 - The function prevents upsampling ratios that would result in '1x1x1' 2078 by defaulting to '2x2x2'. 2079 - It also avoids uncommon ratios like '5' by rounding to the nearest valid option. 2080 - The returned string is commonly used to populate a model filename template: 2081 2082 Example: 2083 >>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1') 2084 >>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras') 2085 """ 2086 minspc = min( list( spacing ) ) 2087 maxspc = max( list( spacing ) ) 2088 ratio = maxspc/minspc 2089 if ratio == 1.0: 2090 ratio = 0.5 2091 roundratio = np.round( ratio ) 2092 tarshaperaw = [] 2093 tarshape = [] 2094 tarshaperound = [] 2095 for k in range( len( spacing ) ): 2096 locrat = spacing[k]/minspc 2097 newspc = spacing[k] * roundratio 2098 tarshaperaw.append( locrat ) 2099 if modality == "NM": 2100 if verbose: 2101 print("Using minspacing: 0.25") 2102 if newspc < 0.25 : 2103 locrat = spacing[k]/0.25 2104 elif modality == "DTI": 2105 if verbose: 2106 print("Using minspacing: 1.0") 2107 if newspc < 1.0 : 2108 locrat = spacing[k]/1.0 2109 else: # assume T1 2110 if verbose: 2111 print("Using minspacing: 0.35") 2112 if newspc < 0.35 : 2113 locrat = spacing[k]/0.35 2114 myint = int( locrat ) 2115 if ( myint == 0 ): 2116 myint = 1 2117 if myint == 5: 2118 myint = 4 2119 if ( myint > 6 ): 2120 myint = 6 2121 tarshape.append( str( myint ) ) 2122 tarshaperound.append( str( int(np.round( locrat )) ) ) 2123 if verbose: 2124 print("before emendation:") 2125 print( tarshaperaw ) 2126 print( tarshaperound ) 2127 print( tarshape ) 2128 allone = True 2129 if roundit: 2130 tarshape = tarshaperound 2131 for k in range( len( tarshape ) ): 2132 if tarshape[k] != "1": 2133 allone=False 2134 if allone: 2135 tarshape = ["2","2","2"] # default 2136 return "x".join(tarshape)
Compute the optimal upsampling shape string (e.g., '2x2x2') based on image voxel spacing and imaging modality. This output is used to select an appropriate pretrained super-resolution model filename.
Parameters
spacing : sequence of float
    Voxel spacing (physical size per voxel in mm) from the input image.
    Typically obtained from ants.get_spacing(image).
modality : str, optional
    Imaging modality. Affects resolution thresholds:
    - 'T1' : anatomical MRI (default minimum spacing: 0.35 mm)
    - 'DTI' : diffusion MRI (default minimum spacing: 1.0 mm)
    - 'NM' : nuclear medicine (e.g., PET/SPECT, minimum spacing: 0.25 mm)
roundit : bool, optional
    If True, uses rounded integer ratios for the upsampling shape.
    Otherwise, uses floor division with constraints.
verbose : bool, optional
    If True, prints detailed internal values and logic.
Returns
str
    Optimal upsampling shape string in the form 'AxBxC', e.g., '2x2x2', '4x4x2'.
Notes
- The function prevents upsampling ratios that would result in '1x1x1' by defaulting to '2x2x2'.
- It also avoids uncommon ratios like '5' by rounding to the nearest valid option.
- The returned string is commonly used to populate a model filename template:
Example:
>>> bestup = optimize_upsampling_shape(ants.get_spacing(t1_img), modality='T1')
>>> model = re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras')
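A concrete trace of the example above with an assumed anisotropic spacing; whether a pretrained model matching the resulting filename actually exists depends on what has been released:
>>> import re
>>> import siq
>>> spc = (0.8, 0.8, 3.0)    # hypothetical anisotropic T1 spacing, in mm
>>> bestup = siq.optimize_upsampling_shape(spc, modality='T1')
>>> bestup
'1x1x3'
>>> re.sub('bestup', bestup, 'siq_smallshort_train_bestup_1chan.keras')
'siq_smallshort_train_1x1x3_1chan.keras'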
2138def compare_models( model_filenames, img, n_classes=3, 2139 poly_order='hist', 2140 identifier=None, noise_sd=0.1,verbose=False ): 2141 """ 2142 Evaluates and compares the performance of multiple super-resolution models on a given image. 2143 2144 This function provides a standardized way to benchmark SR models. For each model, 2145 it performs the following steps: 2146 1. Loads the model and determines its upsampling factor. 2147 2. Downsamples the high-resolution input image (`img`) to create a low-resolution 2148 input, simulating a real-world scenario. 2149 3. Adds Gaussian noise to the low-resolution input to test for robustness. 2150 4. Runs inference using the model to generate a super-resolved output. 2151 5. Generates a baseline output by upsampling the low-res input with linear interpolation. 2152 6. Calculates PSNR and SSIM metrics comparing both the model's output and the 2153 baseline against the original high-resolution image. 2154 7. If a dual-channel (image + segmentation) model is detected, it also calculates 2155 Dice scores for segmentation performance. 2156 8. Aggregates all results into a pandas DataFrame for easy comparison. 2157 2158 Parameters 2159 ---------- 2160 model_filenames : list of str 2161 A list of file paths to the Keras models (.h5, .keras) to be compared. 2162 img : ants.ANTsImage 2163 The high-resolution ground truth image. This image will be downsampled to 2164 create the input for the models. 2165 n_classes : int, optional 2166 The number of classes for Otsu's thresholding when auto-generating a 2167 segmentation for evaluating dual-channel models. Default is 3. 2168 poly_order : str or int, optional 2169 Method for intensity matching between the SR output and the reference. 2170 Options: 'hist' for histogram matching (default), an integer for 2171 polynomial regression, or None to disable. 2172 identifier : str, optional 2173 A custom identifier for the output DataFrame. If None, it is inferred 2174 from the model filename. Default is None. 2175 noise_sd : float, optional 2176 Standard deviation of the additive Gaussian noise applied to the 2177 downsampled image before inference. Default is 0.1. 2178 verbose : bool, optional 2179 If True, prints detailed progress and intermediate values. Default is False. 2180 2181 Returns 2182 ------- 2183 pd.DataFrame 2184 A DataFrame where each row corresponds to a model. Columns contain evaluation 2185 metrics (PSNR.SR, SSIM.SR, DICE.SR), baseline metrics (PSNR.LIN, SSIM.LIN, 2186 DICE.NN), and metadata. 2187 2188 Notes 2189 ----- 2190 When evaluating a 2-channel (segmentation) model, the primary metric for the 2191 segmentation task is the Dice score (`DICE.SR`). The intensity metrics (PSNR, SSIM) 2192 are still computed on the first channel. 
2193 """ 2194 padding=4 2195 mydf = pd.DataFrame() 2196 for k in range( len( model_filenames ) ): 2197 srmdl, upshape = read_srmodel( model_filenames[k] ) 2198 if verbose: 2199 print( model_filenames[k] ) 2200 print( upshape ) 2201 tarshape = [] 2202 inspc = ants.get_spacing(img) 2203 for j in range(len(img.shape)): 2204 tarshape.append( float(upshape[j]) * inspc[j] ) 2205 # uses linear interp 2206 dimg=ants.resample_image( img, tarshape, use_voxels=False, interp_type=0 ) 2207 dimg = ants.add_noise_to_image( dimg,'additivegaussian', [0,noise_sd] ) 2208 import math 2209 dicesr=math.nan 2210 dicenn=math.nan 2211 if upshape[3] == 2: 2212 seghigh = ants.threshold_image( img,"Otsu",n_classes) 2213 seglow = ants.resample_image( seghigh, tarshape, use_voxels=False, interp_type=1 ) 2214 dimgup=inference( dimg, srmdl, segmentation = seglow, poly_order=poly_order, verbose=verbose ) 2215 dimgupseg = dimgup['super_resolution_segmentation'] 2216 dimgup = dimgup['super_resolution'] 2217 segblock = ants.resample_image_to_target( seghigh, dimgupseg, interp_type='nearestNeighbor' ) 2218 segimgnn = ants.resample_image_to_target( seglow, dimgupseg, interp_type='nearestNeighbor' ) 2219 segblock[ dimgupseg == 0 ] = 0 2220 segimgnn[ dimgupseg == 0 ] = 0 2221 dicenn = ants.label_overlap_measures(segblock, segimgnn)['MeanOverlap'][0] 2222 dicesr = ants.label_overlap_measures(segblock, dimgupseg)['MeanOverlap'][0] 2223 else: 2224 dimgup=inference( dimg, srmdl, poly_order=poly_order, verbose=verbose ) 2225 dimglin = ants.resample_image_to_target( dimg, dimgup, interp_type='linear' ) 2226 imgblock = ants.resample_image_to_target( img, dimgup, interp_type='linear' ) 2227 dimgup[ imgblock == 0.0 ]=0.0 2228 dimglin[ imgblock == 0.0 ]=0.0 2229 padder = [] 2230 dimwarning=False 2231 for jj in range(img.dimension): 2232 padder.append( padding ) 2233 if img.shape[jj] != imgblock.shape[jj]: 2234 dimwarning=True 2235 if dimwarning: 2236 print("NOTE: dimensions of downsampled to upsampled image do not match!!!") 2237 print("we force them to match but this suggests results may not be reliable.") 2238 temp = os.path.basename( model_filenames[k] ) 2239 temp = re.sub( "siq_default_sisr_", "", temp ) 2240 temp = re.sub( "_best_mdl.keras", "", temp ) 2241 temp = re.sub( "_best_mdl.h5", "", temp ) 2242 if verbose and dimwarning: 2243 print( "original img shape" ) 2244 print( img.shape ) 2245 print( "resampled img shape" ) 2246 print( imgblock.shape ) 2247 a=[] 2248 imgshape = [] 2249 for aa in range(len(upshape)): 2250 a.append( str(upshape[aa]) ) 2251 if aa < len(imgblock.shape): 2252 imgshape.append( str( imgblock.shape[aa] ) ) 2253 if identifier is None: 2254 identifier=temp 2255 mydict = { 2256 "identifier":identifier, 2257 "imgshape":"x".join(imgshape), 2258 "mdl": temp, 2259 "mdlshape":"x".join(a), 2260 "PSNR.LIN": antspynet.psnr( imgblock, dimglin ), 2261 "PSNR.SR": antspynet.psnr( imgblock, dimgup ), 2262 "SSIM.LIN": antspynet.ssim( imgblock, dimglin ), 2263 "SSIM.SR": antspynet.ssim( imgblock, dimgup ), 2264 "DICE.NN": dicenn, 2265 "DICE.SR": dicesr, 2266 "dimwarning": dimwarning } 2267 if verbose: 2268 print( mydict ) 2269 temp = pd.DataFrame.from_records( [mydict], index=[0] ) 2270 mydf = pd.concat( [mydf,temp], axis=0 ) 2271 # end loop 2272 return mydf
Evaluates and compares the performance of multiple super-resolution models on a given image.
This function provides a standardized way to benchmark SR models. For each model, it performs the following steps:
1. Loads the model and determines its upsampling factor.
2. Downsamples the high-resolution input image (img) to create a low-resolution
   input, simulating a real-world scenario.
3. Adds Gaussian noise to the low-resolution input to test for robustness.
4. Runs inference using the model to generate a super-resolved output.
5. Generates a baseline output by upsampling the low-res input with linear interpolation.
6. Calculates PSNR and SSIM metrics comparing both the model's output and the
   baseline against the original high-resolution image.
7. If a dual-channel (image + segmentation) model is detected, it also calculates
   Dice scores for segmentation performance.
8. Aggregates all results into a pandas DataFrame for easy comparison.
Parameters
model_filenames : list of str
    A list of file paths to the Keras models (.h5, .keras) to be compared.
img : ants.ANTsImage
    The high-resolution ground truth image. This image will be downsampled to
    create the input for the models.
n_classes : int, optional
    The number of classes for Otsu's thresholding when auto-generating a
    segmentation for evaluating dual-channel models. Default is 3.
poly_order : str or int, optional
    Method for intensity matching between the SR output and the reference.
    Options: 'hist' for histogram matching (default), an integer for
    polynomial regression, or None to disable.
identifier : str, optional
    A custom identifier for the output DataFrame. If None, it is inferred
    from the model filename. Default is None.
noise_sd : float, optional
    Standard deviation of the additive Gaussian noise applied to the
    downsampled image before inference. Default is 0.1.
verbose : bool, optional
    If True, prints detailed progress and intermediate values. Default is False.
Returns
pd.DataFrame
    A DataFrame where each row corresponds to a model. Columns contain evaluation
    metrics (PSNR.SR, SSIM.SR, DICE.SR), baseline metrics (PSNR.LIN, SSIM.LIN,
    DICE.NN), and metadata.
Notes
When evaluating a 2-channel (segmentation) model, the primary metric for the
segmentation task is the Dice score (DICE.SR). The intensity metrics (PSNR, SSIM)
are still computed on the first channel.
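A small benchmarking sketch; the model filenames are placeholders for locally trained or downloaded weights, and the simulated image stands in for a real ground-truth volume:
>>> import siq
>>> img = siq.simulate_image([48, 48, 48], n_levels=8)           # stand-in high-resolution image
>>> mdls = ['siq_default_sisr_2x2x2_1chan_best_mdl.keras',       # hypothetical model files
...         'siq_default_sisr_4x4x2_1chan_best_mdl.keras']
>>> results = siq.compare_models(mdls, img, noise_sd=0.05, verbose=True)
>>> results[['mdl', 'PSNR.SR', 'PSNR.LIN', 'SSIM.SR', 'SSIM.LIN']]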
2277def region_wise_super_resolution(image, mask, super_res_model, dilation_amount=4, verbose=False): 2278 """ 2279 Apply super-resolution model to each labeled region in the mask independently. 2280 2281 Arguments 2282 --------- 2283 image : ANTsImage 2284 Input image. 2285 2286 mask : ANTsImage 2287 Integer-labeled segmentation mask with non-zero regions to upsample. 2288 2289 super_res_model : tf.keras.Model 2290 Trained super-resolution model. 2291 2292 dilation_amount : int 2293 Number of morphological dilations applied to each label region before cropping. 2294 2295 verbose : bool 2296 If True, print detailed status. 2297 2298 Returns 2299 ------- 2300 ANTsImage : Full-size super-resolved image with per-label inference and stitching. 2301 """ 2302 import ants 2303 import numpy as np 2304 from antspynet import apply_super_resolution_model_to_image 2305 2306 upFactor = [] 2307 input_shape = super_res_model.inputs[0].shape 2308 test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1] 2309 test_input = np.zeros(test_shape, dtype=np.float32) 2310 test_output = super_res_model(test_input) 2311 2312 for k in range(len(test_shape) - 2): # ignore batch + channel 2313 upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1])) 2314 2315 original_size = mask.shape # e.g., (x, y, z) 2316 new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor)) 2317 upsampled_mask = ants.resample_image(mask, new_size, use_voxels=True, interp_type=1) 2318 upsampled_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0) 2319 2320 unique_labels = list(np.unique(upsampled_mask.numpy())) 2321 if 0 in unique_labels: 2322 unique_labels.remove(0) 2323 2324 outimg = ants.image_clone(upsampled_image) 2325 2326 for lab in unique_labels: 2327 if verbose: 2328 print(f"Processing label: {lab}") 2329 regionmask = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount) 2330 cropped = ants.crop_image(image, regionmask) 2331 if cropped.shape[0] == 0: 2332 continue 2333 subimgsr = apply_super_resolution_model_to_image( 2334 cropped, super_res_model, target_range=[0, 1], verbose=verbose 2335 ) 2336 stitched = ants.decrop_image(subimgsr, outimg) 2337 outimg[upsampled_mask == lab] = stitched[upsampled_mask == lab] 2338 2339 return outimg
Apply super-resolution model to each labeled region in the mask independently.
Arguments
image : ANTsImage
    Input image.
mask : ANTsImage
    Integer-labeled segmentation mask with non-zero regions to upsample.
super_res_model : tf.keras.Model
    Trained super-resolution model.
dilation_amount : int
    Number of morphological dilations applied to each label region before cropping.
verbose : bool
    If True, print detailed status.
Returns
ANTsImage : Full-size super-resolved image with per-label inference and stitching.
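An illustrative call using an Otsu-derived label mask; the filenames are hypothetical:
>>> import ants
>>> import siq
>>> img = ants.image_read('lowres_t1.nii.gz')             # hypothetical input image
>>> mask = ants.threshold_image(img, 'Otsu', 3)           # integer labels 1..3
>>> mdl, up = siq.read_srmodel('siq_2x2x2_1chan.keras')   # hypothetical single-channel model
>>> sr = siq.region_wise_super_resolution(img, mask, mdl, dilation_amount=4, verbose=True)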
2342def region_wise_super_resolution_blended(image, mask, super_res_model, dilation_amount=4, verbose=False): 2343 """ 2344 Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts. 2345 2346 This version uses a weighted-averaging scheme based on distance transforms 2347 to create seamless transitions between super-resolved regions and the background. 2348 2349 Arguments 2350 --------- 2351 image : ANTsImage 2352 Input low-resolution image. 2353 2354 mask : ANTsImage 2355 Integer-labeled segmentation mask. 2356 2357 super_res_model : tf.keras.Model 2358 Trained super-resolution model. 2359 2360 dilation_amount : int 2361 Number of morphological dilations applied to each label region before cropping. 2362 This provides context to the SR model. 2363 2364 verbose : bool 2365 If True, print detailed status. 2366 2367 Returns 2368 ------- 2369 ANTsImage : Full-size, super-resolved image with seamless blending. 2370 """ 2371 import ants 2372 import numpy as np 2373 from antspynet import apply_super_resolution_model_to_image 2374 epsilon32 = np.finfo(np.float32).eps 2375 normalize_weight_maps = True # Default behavior to normalize weight maps 2376 # --- Step 1: Determine upsampling factor and prepare initial images --- 2377 upFactor = [] 2378 input_shape = super_res_model.inputs[0].shape 2379 test_shape = [1, 8, 8, 1] if len(input_shape) == 4 else [1, 8, 8, 8, 1] 2380 test_input = np.zeros(test_shape, dtype=np.float32) 2381 test_output = super_res_model(test_input) 2382 for k in range(len(test_shape) - 2): 2383 upFactor.append(int(test_output.shape[k + 1] / test_input.shape[k + 1])) 2384 2385 original_size = image.shape 2386 new_size = tuple(int(s * f) for s, f in zip(original_size, upFactor)) 2387 2388 # The initial upsampled image will serve as our background 2389 background_sr_image = ants.resample_image(image, new_size, use_voxels=True, interp_type=0) 2390 2391 # --- Step 2: Initialize accumulator and weight sum canvases --- 2392 # These must be float type for accumulation 2393 accumulator = ants.image_clone(background_sr_image).astype('float32') * 0.0 2394 weight_sum = ants.image_clone(accumulator) 2395 2396 unique_labels = [l for l in np.unique(mask.numpy()) if l != 0] 2397 2398 for lab in unique_labels: 2399 if verbose: 2400 print(f"Blending label: {lab}") 2401 2402 # --- Step 3: Super-resolve a dilated patch (provides context to the model) --- 2403 region_mask_dilated = ants.threshold_image(mask, lab, lab).iMath("MD", dilation_amount) 2404 cropped_lowres = ants.crop_image(image, region_mask_dilated) 2405 if cropped_lowres.shape[0] == 0: 2406 continue 2407 2408 # Apply the model to the cropped low-res patch 2409 sr_patch = apply_super_resolution_model_to_image( 2410 cropped_lowres, super_res_model, target_range=[0, 1] 2411 ) 2412 2413 # Place the super-resolved patch back onto a full-sized canvas 2414 sr_patch_full_size = ants.decrop_image(sr_patch, accumulator) 2415 2416 # --- Step 4: Create a smooth weight map for this region --- 2417 # We use the *non-dilated* mask for the weight map to ensure a sharp focus on the target region. 
2418 region_mask_original = ants.threshold_image(mask, lab, lab) 2419 2420 # Resample the original region mask to the high-res grid 2421 weight_map = ants.resample_image(region_mask_original, new_size, use_voxels=True, interp_type=0) 2422 weight_map = ants.smooth_image(weight_map, sigma=2.0, 2423 sigma_in_physical_coordinates=False) 2424 if normalize_weight_maps: 2425 weight_map = ants.iMath(weight_map, "Normalize") 2426 # --- Step 5: Accumulate the weighted values and the weights themselves --- 2427 accumulator += sr_patch_full_size * weight_map 2428 weight_sum += weight_map 2429 2430 # --- Step 6: Final Combination --- 2431 # Normalize the accumulator by the total weight at each pixel 2432 weight_sum_np = weight_sum.numpy() 2433 accumulator_np = accumulator.numpy() 2434 2435 # Create a mask of pixels where blending occurred 2436 blended_mask = weight_sum_np > 0.0 # Use a small epsilon for float safety 2437 2438 # Start with the original upsampled image as the base 2439 final_image_np = background_sr_image.numpy() 2440 2441 # Perform the weighted average only where weights are non-zero 2442 final_image_np[blended_mask] = accumulator_np[blended_mask] / weight_sum_np[blended_mask] 2443 2444 # Re-insert any non-blended background regions that were processed 2445 # This handles cases where regions overlap; the weighted average takes care of it. 2446 2447 return ants.from_numpy(final_image_np, origin=background_sr_image.origin, 2448 spacing=background_sr_image.spacing, direction=background_sr_image.direction)
Apply super-resolution model to labeled regions with smooth blending to minimize stitching artifacts.
This version uses a weighted-averaging scheme based on distance transforms to create seamless transitions between super-resolved regions and the background.
Arguments
image : ANTsImage
    Input low-resolution image.
mask : ANTsImage
    Integer-labeled segmentation mask.
super_res_model : tf.keras.Model
    Trained super-resolution model.
dilation_amount : int
    Number of morphological dilations applied to each label region before cropping.
    This provides context to the SR model.
verbose : bool
    If True, print detailed status.
Returns
ANTsImage : Full-size, super-resolved image with seamless blending.
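The call signature matches region_wise_super_resolution; the practical difference is how per-label patches are merged (hard label-wise replacement versus smooth weighted averaging). An illustrative comparison with hypothetical filenames:
>>> import ants
>>> import siq
>>> img = ants.image_read('lowres_t1.nii.gz')             # hypothetical input image
>>> mask = ants.threshold_image(img, 'Otsu', 3)
>>> mdl, up = siq.read_srmodel('siq_2x2x2_1chan.keras')   # hypothetical model file
>>> sr_hard = siq.region_wise_super_resolution(img, mask, mdl)
>>> sr_soft = siq.region_wise_super_resolution_blended(img, mask, mdl)
>>> diff = sr_soft - sr_hard      # differences concentrate near label boundaries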
2451def inference( 2452 image, 2453 mdl, 2454 truncation=None, 2455 segmentation=None, 2456 target_range=[1, 0], 2457 poly_order='hist', 2458 dilation_amount=0, 2459 verbose=False): 2460 """ 2461 Perform super-resolution inference on an input image, optionally guided by segmentation. 2462 2463 This function uses a trained deep learning model to enhance the resolution of a medical image. 2464 It optionally applies label-wise inference if a segmentation mask is provided. 2465 2466 Parameters 2467 ---------- 2468 image : ants.ANTsImage 2469 Input image to be super-resolved. 2470 2471 mdl : keras.Model 2472 Trained super-resolution model, typically from ANTsPyNet. 2473 2474 truncation : tuple or list of float, optional 2475 Percentile values (e.g., [0.01, 0.99]) for intensity truncation before model input. 2476 If None, no truncation is applied. 2477 2478 segmentation : ants.ANTsImage, optional 2479 A labeled segmentation mask. If provided, super-resolution is performed per label 2480 using `region_wise_super_resolution` or `super_resolution_segmentation_per_label`. 2481 2482 target_range : list of float 2483 Intensity range used for scaling the input before applying the model. 2484 Default is [1, 0] (internal default for `apply_super_resolution_model_to_image`). 2485 2486 poly_order : int, str or None 2487 Determines how to match intensity between the super-resolved image and the original. 2488 Options: 2489 - 'hist' : use histogram matching 2490 - int >= 1 : perform polynomial regression of this order 2491 - None : no intensity adjustment 2492 2493 dilation_amount : int 2494 Number of dilation steps applied to each segmentation label during 2495 region-based super-resolution (if segmentation is provided). 2496 2497 verbose : bool 2498 If True, print progress and status messages. 2499 2500 Returns 2501 ------- 2502 ANTsImage or dict 2503 - If `segmentation` is None, returns a single ANTsImage (super-resolved image). 2504 - If `segmentation` is provided, returns a dictionary with: 2505 - 'super_resolution': ANTsImage 2506 - other entries may include label-wise results or metadata. 
2507 2508 Examples 2509 -------- 2510 >>> import ants 2511 >>> import antspynet 2512 >>> from siq import inference 2513 >>> img = ants.image_read("lowres.nii.gz") 2514 >>> model = antspynet.get_pretrained_network("dbpn", target_suffix="T1") 2515 >>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True) 2516 2517 >>> seg = ants.image_read("mask.nii.gz") 2518 >>> sr_result = inference(img, model, segmentation=seg) 2519 >>> srimg = sr_result['super_resolution'] 2520 """ 2521 import ants 2522 import numpy as np 2523 import antspynet 2524 import antspyt1w 2525 from siq import region_wise_super_resolution 2526 2527 def apply_intensity_match(sr_image, reference_image, order, verbose=False): 2528 if order is None: 2529 return sr_image 2530 if verbose: 2531 print("Applying intensity match with", order) 2532 if order == 'hist': 2533 return ants.histogram_match_image(sr_image, reference_image) 2534 else: 2535 return ants.regression_match_image(sr_image, reference_image, poly_order=order) 2536 2537 pimg = ants.image_clone(image) 2538 if truncation is not None: 2539 pimg = ants.iMath(pimg, 'TruncateIntensity', truncation[0], truncation[1]) 2540 2541 input_shape = mdl.inputs[0].shape 2542 num_channels = int(input_shape[-1]) 2543 2544 if segmentation is not None: 2545 if num_channels == 1: 2546 if verbose: 2547 print("Using region-wise super resolution due to single-channel model with segmentation.") 2548 sr = region_wise_super_resolution_blended( 2549 pimg, segmentation, mdl, 2550 dilation_amount=dilation_amount, 2551 verbose=verbose 2552 ) 2553 ref = ants.resample_image_to_target(pimg, sr) 2554 return apply_intensity_match(sr, ref, poly_order, verbose) 2555 else: 2556 mynp = segmentation.numpy() 2557 mynp = list(np.unique(mynp)[1:len(mynp)].astype(int)) 2558 upFactor = [] 2559 if len(input_shape) == 5: 2560 testarr = np.zeros([1, 8, 8, 8, 2]) 2561 testarrout = mdl(testarr) 2562 for k in range(3): 2563 upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1])) 2564 elif len(input_shape) == 4: 2565 testarr = np.zeros([1, 8, 8, 2]) 2566 testarrout = mdl(testarr) 2567 for k in range(2): 2568 upFactor.append(int(testarrout.shape[k + 1] / testarr.shape[k + 1])) 2569 temp = antspyt1w.super_resolution_segmentation_per_label( 2570 pimg, 2571 segmentation, 2572 upFactor, 2573 mdl, 2574 segmentation_numbers=mynp, 2575 target_range=target_range, 2576 dilation_amount=dilation_amount, 2577 poly_order=poly_order, 2578 max_lab_plus_one=True 2579 ) 2580 imgsr = temp['super_resolution'] 2581 ref = ants.resample_image_to_target(pimg, imgsr) 2582 return apply_intensity_match(imgsr, ref, poly_order, verbose) 2583 2584 # Default path: no segmentation 2585 imgsr = antspynet.apply_super_resolution_model_to_image( 2586 pimg, mdl, target_range=target_range, regression_order=None, verbose=verbose 2587 ) 2588 ref = ants.resample_image_to_target(pimg, imgsr) 2589 return apply_intensity_match(imgsr, ref, poly_order, verbose)
Perform super-resolution inference on an input image, optionally guided by segmentation.
This function uses a trained deep learning model to enhance the resolution of a medical image. It optionally applies label-wise inference if a segmentation mask is provided.
Parameters
image : ants.ANTsImage
    Input image to be super-resolved.
mdl : keras.Model
    Trained super-resolution model, typically from ANTsPyNet.
truncation : tuple or list of float, optional
    Percentile values (e.g., [0.01, 0.99]) for intensity truncation before model input.
    If None, no truncation is applied.
segmentation : ants.ANTsImage, optional
    A labeled segmentation mask. If provided, super-resolution is performed per label
    using region_wise_super_resolution or super_resolution_segmentation_per_label.
target_range : list of float
    Intensity range used for scaling the input before applying the model.
    Default is [1, 0] (internal default for apply_super_resolution_model_to_image).
poly_order : int, str or None
    Determines how to match intensity between the super-resolved image and the original.
    Options:
    - 'hist' : use histogram matching
    - int >= 1 : perform polynomial regression of this order
    - None : no intensity adjustment
dilation_amount : int
    Number of dilation steps applied to each segmentation label during
    region-based super-resolution (if segmentation is provided).
verbose : bool
    If True, print progress and status messages.
Returns
ANTsImage or dict
    - If segmentation is None, returns a single ANTsImage (super-resolved image).
    - If segmentation is provided, returns a dictionary with:
      - 'super_resolution': ANTsImage
      - other entries may include label-wise results or metadata.
Examples
>>> import ants
>>> import antspynet
>>> from siq import inference
>>> img = ants.image_read("lowres.nii.gz")
>>> model = antspynet.get_pretrained_network("dbpn", target_suffix="T1")
>>> srimg = inference(img, model, truncation=[0.01, 0.99], verbose=True)
>>> seg = ants.image_read("mask.nii.gz")
>>> sr_result = inference(img, model, segmentation=seg)
>>> srimg = sr_result['super_resolution']
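A short sketch of the intensity-matching options described under poly_order, with hypothetical filenames:
>>> import ants
>>> import siq
>>> img = ants.image_read('lowres.nii.gz')                # hypothetical input
>>> mdl, up = siq.read_srmodel('siq_2x2x2_1chan.keras')   # hypothetical model file
>>> sr_hist = siq.inference(img, mdl, poly_order='hist')  # histogram matching (default)
>>> sr_poly = siq.inference(img, mdl, poly_order=2)       # 2nd-order polynomial regression match
>>> sr_none = siq.inference(img, mdl, poly_order=None)    # leave model output intensities as-is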