image_to_image.models.transformer
Module to define basic Transformer parts.
Also defines a Transformer model for image-to-image tasks, the PhysicFormer.
Classes:
- PatchEmbedding
- PositionalEncoding
- Attention
- TransformerEncoderBlock
- CNNRefinement
- PhysicFormer
By Tobia Ippolito
"""
Module to define basic Transformer parts.<br>
Also defines a Transformer model for image-to-image tasks, the PhysicFormer.

Classes:
- PatchEmbedding
- PositionalEncoding
- Attention
- TransformerEncoderBlock
- CNNRefinement
- PhysicFormer

By Tobia Ippolito
"""
# ---------------------------
# > Imports <
# ---------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F



# ---------------------------
# > Patching <
# ---------------------------
class PatchEmbedding(nn.Module):
    """
    Converts an Image into Patches ('Tokens' in NLP).

    Image size must be: H x W x C

    Patch size must be: P x P
    """
    def __init__(self, img_size, patch_size=16, input_channels=1, embedded_dim=768):
        """
        Init of Patching.

        Each filter creates one new value for each patch,
        and this with embedded_dim filters. <br>
        So one patch is projected into an 'embedded_dim'-vector. <br>
        For example, the values at [0, 0] across all channels of the embedded image together form
        the embedding vector of the first patch.

        Parameter:
        - img_size (int):
            Size (width and height) of your image. Your image must have the same width and height.
        - patch_size (int, default=16):
            Size (width and height) of one patch.
        - input_channels (int):
            Number of input channels expected.
        - embedded_dim (int):
            Output channels / channels of the embedding.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = input_channels
        self.embedded_dim = embedded_dim

        self.projection = nn.Conv2d(input_channels, embedded_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Forward pass of patching.

        Parameter:
        - x (torch.tensor):
            Input Image(s).

        Returns:
        - tuple(torch.tensor, tuple(int, int)):
            The embedded patches with the height and width of the patch grid.
        """
        # Input: (batch_size, in_channels, height, width)
        x = self.projection(x)    # (batch_size, embedded_dim, img_size/patch_size, img_size/patch_size)
        B, C, H, W = x.shape
        x = x.flatten(2)          # (batch_size, embedded_dim, num_patches)
        x = x.transpose(1, 2)     # (batch_size, num_patches, embedded_dim)
        return x, (H, W)



# ---------------------------
# > Positional Encoding <
# ---------------------------
class PositionalEncoding(nn.Module):
    """
    Adds learnable parameters which encode the positional information of the patches.
    The position of a patch is important, because a picture only makes sense
    if the order of the sub-pictures (patches) is right.
    """
    def __init__(self, num_patches, embedded_dim=768):
        """
        Init of Positional Encoding.

        Parameter:
        - num_patches (int):
            Number of patches ('tokens').
        - embedded_dim (int, default=768):
            Number of embedding channels.
        """
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches, embedded_dim))

    def forward(self, x):
        """
        Forward pass that adds the positional information.

        Parameter:
        - x (torch.tensor):
            Patch embedded image(s).

        Returns:
        - torch.tensor:
            The embedded image(s) with positional encoding added.
        """
        # use broadcast addition
        return x + self.positional_embedding



# ---------------------------
# > Attention <
# ---------------------------
class Attention(nn.Module):
    """
    The basic element of a Transformer is the attention layer. <br>
    The attention layer computes relations between all patches. <br>
    This is done by calculating the similarity between 2 learnable projections, Q and K.
    """
    def __init__(self, embedded_dim, num_heads):
        """
        Init of Attention Layer.

        Parameter:
        - embedded_dim (int):
            Patch Embedding Channels.
        - num_heads (int):
            Number of parallel attention computations.
        """
        super().__init__()
        assert embedded_dim % num_heads == 0, \
            f"embedded_dim ({embedded_dim}) must be divisible by num_heads ({num_heads})"

        self.num_heads = num_heads
        self.head_dim = embedded_dim // num_heads
        self.scale = self.head_dim**-0.5    # factor for scaling
        self.qkv = nn.Linear(embedded_dim, embedded_dim*3)
        self.fc = nn.Linear(embedded_dim, embedded_dim)

    def forward(self, x):
        """
        Forward pass of Attention Layer.

        softmax(QK^T / sqrt(d_head)) V

        Parameter:
        - x (torch.tensor):
            Patch embedded image(s) with positional encoding added.

        Returns:
        - torch.tensor:
            The attention output passed through a fully connected layer.
        """
        batch_size, num_patches, embedded_dim = x.shape
        qkv = self.qkv(x)    # (batch_size, num_patches, embedded_dim*3)
        qkv = qkv.reshape(batch_size, num_patches, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)    # -> (3, batch_size, num_heads, num_patches, head_dim)
        q, k, v = qkv.unbind(dim=0)

        # compute scaled dot-product
        # attention_scores = (q@k.transpose(-2, -1)) * self.scale    # (batch_size, num_heads, num_patches, num_patches)
        # attention_weights = attention_scores.softmax(dim=-1)
        # attention_output = (attention_weights@v).transpose(1, 2).reshape(batch_size, num_patches, embedded_dim)

        # -> Memory-efficient attention score version (uses flash-attention when available)
        attention_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        attention_output = attention_output.transpose(1, 2).reshape(
            batch_size, num_patches, embedded_dim
        )

        return self.fc(attention_output)



# ---------------------------
# > Transformer Encoder Block <
# ---------------------------
class TransformerEncoderBlock(nn.Module):
    """
    A Transformer Encoder Block consists of a self-attention layer
    followed by a feed-forward layer (MLP = multi-layer perceptron)
    + (layer) normalization
    and with residual connections / skip connections.
    """
    def __init__(self, embedded_dim, num_heads, mlp_dim, dropout=0.1):
        """
        Init of a Transformer Block.

        Parameter:
        - embedded_dim (int):
            Patch Embedding Channels.
        - num_heads (int):
            Number of parallel attention computations.
        - mlp_dim (int):
            Hidden/feature dimension of the multi-layer perceptron.
        - dropout (float):
            Probability of dropout regularization.
        """
        super().__init__()
        self.norm_1 = nn.LayerNorm(embedded_dim)
        self.attention = Attention(embedded_dim=embedded_dim, num_heads=num_heads)

        self.norm_2 = nn.LayerNorm(embedded_dim)
        # hidden_dim = int(embedded_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embedded_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embedded_dim),
            nn.Dropout(dropout)
        )


    def forward(self, x):
        """
        Forward pass of Transformer Block.

        Parameter:
        - x (torch.tensor):
            Patch embedded image(s) with positional encoding added.

        Returns:
        - torch.tensor:
            The attention output passed through the fully connected layer and the multi-layer perceptron, with layer normalization.
        """
        # self attention with skip connection
        x = self.norm_1(x)
        x = x + self.attention(x)

        # MLP with skip connection
        x = self.norm_2(x)
        x = x + self.mlp(x)

        return x



# ---------------------------
# > CNN Refinement <
# ---------------------------
class CNNRefinement(nn.Module):
    """
    Refinement Network to remove transformer artefacts.
    """
    def __init__(self, input_channels=1, hidden_channels=64, output_channels=1, skip_connection=True):
        """
        Init of a CNN Refinement network.

        Parameter:
        - input_channels (int):
            Number of input channels.
        - hidden_channels (int):
            Number of hidden/feature channels.
        - output_channels (int):
            Number of output channels.
        - skip_connection (bool):
            Should a skip connection be used? I.e. whether a second input (the original image) should be added to the output.
            That changes the CNN to learning a correction which will be applied to the original image.
        """
        super().__init__()

        self.conv_1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.activation_1 = nn.ReLU()

        self.conv_2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.activation_2 = nn.ReLU()

        self.conv_3 = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, stride=1, padding=1)

        # or just:
        # nn.Sequential(
        #     nn.Conv2d(output_channels, 64, 3, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 64, 3, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, output_channels, 3, padding=1)
        # )

        self.skip_connection = skip_connection


    def forward(self, x, original=None):
        """
        Forward pass of CNN Refinement network.

        Parameter:
        - x (torch.tensor):
            Image(s) to refine (the transformer's reconstructed output).
        - original (torch.tensor or None, default=None):
            Original image which will be added at the end if skip_connection is set to True.

        Returns:
        - torch.tensor:
            The refined image.
        """
        x = self.conv_1(x)
        x = self.activation_1(x)

        x = self.conv_2(x)
        x = self.activation_2(x)

        x = self.conv_3(x)

        if self.skip_connection and original is not None:
            # apply correction to image
            x = x + original

        return x

# ---------------------------
# > Img2Img Transformer <
# ---------------------------
# The whole model consists of:
# - patching (tokenizing)
# - add positional encoding
# - Transformer Blocks
class PhysicFormer(nn.Module):
    """
    Image-to-Image Transformer.

    The whole model consists of:
    - Patching (tokenizing)
    - Add positional encoding
    - Transformer Blocks (Attention + MLP)
    - Image Reconstruction/Remapping -> Embedded Space to Pixel Space
    - CNN Refinement

    Model logic:<br>
    - `CNN(Transformer(x))` = Pure Transformation (skip connection = false)
    - `CNN(Transformer(x)) + x_input` = residual refinement (skip connection = true)
    - `Transformer(x) + CNN(Transformer(x)) + x_input` = global field + local correction + geometry/residual (not available yet)
    """
    def __init__(self, input_channels=1, output_channels=1,
                 img_size=256, patch_size=4,
                 embedded_dim=1024, num_blocks=8,
                 heads=16, mlp_dim=2048, dropout=0.1):
        """
        Init of the PhysicFormer model.

        Parameter:
        - input_channels (int):
            Number of input image channels (e.g., 1 for grayscale, 3 for RGB).
        - output_channels (int):
            Number of output image channels.
        - img_size (int):
            Size (height and width) of the input image in pixels.
        - patch_size (int):
            Size of each square patch to split the image into.
            The image size must be divisible by this value.
        - embedded_dim (int):
            Dimension of the patch embedding (feature space size per token).
        - num_blocks (int):
            Number of Transformer Encoder blocks.
        - heads (int):
            Number of attention heads per Attention layer.
        - mlp_dim (int):
            Hidden dimension of the feed-forward (MLP) layers within each Transformer block.
        - dropout (float):
            Dropout probability for regularization applied after positional encoding and inside MLP.
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.patch_size = patch_size

        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              patch_size=patch_size,
                                              input_channels=input_channels,
                                              embedded_dim=embedded_dim)

        num_patches = (img_size//patch_size) * (img_size//patch_size)
        self.positional_encoding = PositionalEncoding(num_patches=num_patches, embedded_dim=embedded_dim)

        self.dropout = nn.Dropout(dropout)

        blocks = []
        for _ in range(num_blocks):
            blocks += [TransformerEncoderBlock(embedded_dim=embedded_dim, num_heads=heads, mlp_dim=mlp_dim, dropout=dropout)]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.to_img = nn.Sequential(
            nn.Linear(embedded_dim, patch_size*patch_size*output_channels)
        )

        self.norm = nn.LayerNorm(embedded_dim)

        self.refinement = CNNRefinement(input_channels=output_channels, hidden_channels=64, output_channels=output_channels, skip_connection=True)


    def get_input_channels(self):
        """
        Returns the number of input channels used by the model.

        Returns:
        - int:
            Number of input channels expected by the model.
        """
        return self.input_channels

    def get_output_channels(self):
        """
        Returns the number of output channels produced by the model.

        Returns:
        - int:
            Number of output channels the model generates.
        """
        return self.output_channels


    def forward(self, x):
        """
        Forward pass of the PhysicFormer network.

        Parameter:
        - x (torch.tensor):
            Input image tensor of shape (batch_size, input_channels, height, width).

        Returns:
        - torch.tensor:
            Refined output image tensor of shape (batch_size, output_channels, height, width),
            with values normalized to [0.0, 1.0].

        Notes:
        - The output passes through a `sigmoid()` activation, ensuring all pixel values ∈ [0, 1].
        - Designed for physics-informed or visual reconstruction tasks where local and global consistency are important.
        """
        x_input = x

        # patch embedding / tokenization
        x, (height, width) = self.patch_embedding(x)

        # encoding / add positional information
        x = self.positional_encoding(x)

        x = self.dropout(x)

        # transformer blocks
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)

        x = self.norm(x)

        # translation to image
        x = self.to_img(x)

        # return it in the right format: [B, C, H, W]
        x = x.transpose(1, 2).reshape(x.shape[0], self.output_channels, self.patch_size*height, self.patch_size*width)
        # when you call .view() right after .transpose(), PyTorch can't reinterpret the data layout -> this is an error.

        # refinement
        x = self.refinement(x, original=x_input)

        # Other version:
        # refined = self.refinement(x)

        # # Combine contributions (global + local + input)
        # x = x + refined + x_input

        return torch.sigmoid(x)    # between 0.0 and 1.0 -> alt: torch.clamp(x, 0.0, 1.0)
class PatchEmbedding(nn.Module)
Converts an Image into Patches ('Tokens' in NLP).
Image size must be: H x W x C
Patch size must be: P x P
def __init__(self, img_size, patch_size=16, input_channels=1, embedded_dim=768)
Init of Patching.
Each convolution filter creates one new value for each patch, and there are embedded_dim filters,
so every patch is projected into an 'embedded_dim'-dimensional vector.
For example, the values at [0, 0] across all channels of the embedded image together form
the embedding vector of the first patch (see the sketch after the parameter list).
Parameter:
- img_size (int): Size (width and height) of your image. The image must have the same width and height.
- patch_size (int, default=16): Size (width and height) of one patch.
- input_channels (int): Number of input channels expected.
- embedded_dim (int): Number of output channels / channels of the embedding.
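A minimal sketch of this idea (my own illustration, not part of the module): a Conv2d whose kernel_size equals its stride applies one independent linear projection per non-overlapping patch, which can be verified against an explicit unfold + matmul using the same weights.

import torch
import torch.nn as nn
import torch.nn.functional as F

patch, channels, dim = 16, 1, 768
conv = nn.Conv2d(channels, dim, kernel_size=patch, stride=patch)

img = torch.randn(1, channels, 64, 64)                      # 4 x 4 = 16 patches
out_conv = conv(img)                                        # (1, 768, 4, 4)

# same computation done by hand: cut out the patches and project them with the conv's own weights
patches = F.unfold(img, kernel_size=patch, stride=patch)    # (1, 256, 16) -> one column per patch
weight = conv.weight.reshape(dim, -1)                       # (768, 256)
out_unfold = weight @ patches + conv.bias[:, None]          # (1, 768, 16)

print(torch.allclose(out_conv.flatten(2), out_unfold, atol=1e-5))   # True (up to float error)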
def forward(self, x)
Forward pass of patching.
Parameter:
- x (torch.tensor): Input Image(s).
Returns:
- tuple(torch.tensor, tuple(int, int)): The embedded patches together with the height and width of the patch grid.
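A quick shape check (hypothetical usage, assuming the class above is imported):

import torch

patch_embedding = PatchEmbedding(img_size=256, patch_size=16, input_channels=1, embedded_dim=768)
tokens, (h, w) = patch_embedding(torch.randn(2, 1, 256, 256))

print(tokens.shape)   # torch.Size([2, 256, 768]) -> 16 x 16 = 256 patches, one 768-dim vector each
print(h, w)           # 16 16                     -> height and width of the patch grid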
class PositionalEncoding(nn.Module)
Adds learnable parameters which encode the positional information of the patches. The position of a patch is important because a picture only makes sense if the order of its sub-pictures (patches) is right.
def __init__(self, num_patches, embedded_dim=768)
Init of Positional Encoding.
Parameter:
- num_patches (int): Number of patches ('tokens').
- embedded_dim (int, default=768): Number of embedding channels.
def forward(self, x)
Forward pass that adds the positional information.
Parameter:
- x (torch.tensor): Patch Embedded Image(s).
Returns:
- torch.tensor: The Embedded image(s) with positional encoding added.
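A small sketch (my illustration, not module code) of the broadcast addition; note that the table is initialised to zeros, so the positional offsets are entirely learned during training.

import torch

positional_encoding = PositionalEncoding(num_patches=256, embedded_dim=768)
tokens = torch.randn(4, 256, 768)            # (batch, num_patches, embedded_dim)

out = positional_encoding(tokens)            # the (1, 256, 768) table broadcasts over the batch of 4
print(out.shape)                             # torch.Size([4, 256, 768])
print(torch.equal(out, tokens))              # True before training, since the table starts at zero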
class Attention(nn.Module)
The basic element of a Transformer is the attention layer.
The attention layer computes relations between all patches.
This is done by calculating the similarity between two learned projections, Q and K.
def __init__(self, embedded_dim, num_heads)
Init of Attention Layer.
Parameter:
- embedded_dim (int): Patch Embedding Channels.
- num_heads (int): Number of parallel attention computations.
def forward(self, x)
Forward pass of Attention Layer.
softmax(QK^T / sqrt(d_head)) V
Parameter:
- x (torch.tensor): Patch Embedded Image(s) with positional encoding added.
Returns:
- torch.tensor: The attention output passed through a fully connected layer.
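For reference, a standalone sketch (not taken from the module) showing that the fused F.scaled_dot_product_attention call used in forward() computes the same softmax(QK^T / sqrt(d_head)) V as the commented-out manual version:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 10, 64)    # (batch, heads, patches, head_dim)
k = torch.randn(1, 4, 10, 64)
v = torch.randn(1, 4, 10, 64)

manual = ((q @ k.transpose(-2, -1)) * (64 ** -0.5)).softmax(dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

print(torch.allclose(manual, fused, atol=1e-5))   # True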
class TransformerEncoderBlock(nn.Module)
A Transformer Encoder Block consists of a self-attention layer followed by a feed-forward layer (MLP, multi-layer perceptron),
plus layer normalization and residual / skip connections.
def __init__(self, embedded_dim, num_heads, mlp_dim, dropout=0.1)
Init of a Transformer Block.
Parameter:
- embedded_dim (int): Patch Embedding Channels.
- num_heads (int): Number of parallel attention computations.
- mlp_dim (int): Hidden/feature dimension of the multi-layer perceptron.
- dropout (float): Probability of dropout regularization.
def forward(self, x)
Forward pass of Transformer Block.
Parameter:
- x (torch.tensor): Patch Embedded Image(s) with positional encoding added.
Returns:
- torch.tensor: The attention output passed through the fully connected layer and the multi-layer perceptron, with layer normalization and skip connections.
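Note that the block normalizes the token stream itself before each sub-layer (x = norm(x); x = x + attention(x)) rather than the more common pre-norm form x = x + attention(norm(x)). A quick shape sketch (hypothetical usage, the classes above assumed to be in scope); since a block maps tokens to tokens of identical shape, any number of blocks can be stacked:

import torch

block = TransformerEncoderBlock(embedded_dim=768, num_heads=12, mlp_dim=2048, dropout=0.1)
tokens = torch.randn(2, 256, 768)    # (batch, num_patches, embedded_dim)

print(block(tokens).shape)           # torch.Size([2, 256, 768])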
class CNNRefinement(nn.Module)
Refinement Network to remove transformer artefacts.
def __init__(self, input_channels=1, hidden_channels=64, output_channels=1, skip_connection=True)
Init of a CNN Refinement network.
Parameter:
- input_channels (int): Number of input channels.
- hidden_channels (int): Number of hidden/feature channels.
- output_channels (int): Number of output channels.
- skip_connection (bool): Whether a skip connection should be used, i.e. whether a second input (the original image) is added to the output. In that case the CNN learns a correction that is applied to the original image.
def forward(self, x, original=None)
Forward pass of CNN Refinement network.
Parameter:
- x (torch.tensor): Image(s) to refine (the transformer's reconstructed output).
- original (torch.tensor or None, default=None): Original image which will be added at the end if skip_connection is set to True.
Returns:
- torch.tensor: The refined image.
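A hypothetical usage sketch (not part of the module) of both modes:

import torch

refinement = CNNRefinement(input_channels=1, hidden_channels=64, output_channels=1, skip_connection=True)
coarse = torch.randn(1, 1, 256, 256)      # e.g. the transformer's reconstruction
original = torch.randn(1, 1, 256, 256)    # e.g. the original network input

refined = refinement(coarse, original=original)   # learned correction added onto `original`
plain = refinement(coarse)                        # no skip term, because `original` is None
print(refined.shape, plain.shape)                 # both torch.Size([1, 1, 256, 256])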
class PhysicFormer(nn.Module)
Image-to-Image Transformer.
The whole model consists of:
- Patching (tokenizing)
- Add positional encoding
- Transformer Blocks (Attention + MLP)
- Image Reconstruction/Remapping -> Embedded Space to Pixel Space
- CNN Refinement
Model logic:
- `CNN(Transformer(x))` = Pure Transformation (skip connection = false)
- `CNN(Transformer(x)) + x_input` = residual refinement (skip connection = true)
- `Transformer(x) + CNN(Transformer(x)) + x_input` = global field + local correction + geometry/residual (not available yet)
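The three variants can be paraphrased as follows (a rough sketch in my own words, not code from the module; `transformer` stands for the patching/encoder stages and `cnn` for the CNNRefinement stage):

import torch

def pure_transformation(x, transformer, cnn):
    return torch.sigmoid(cnn(transformer(x)))               # skip connection = false

def residual_refinement(x, transformer, cnn):
    return torch.sigmoid(cnn(transformer(x), original=x))   # skip connection = true (what forward() does)

def global_plus_local(x, transformer, cnn):                 # not available yet
    y = transformer(x)
    return torch.sigmoid(y + cnn(y) + x)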
def __init__(self, input_channels=1, output_channels=1, img_size=256, patch_size=4, embedded_dim=1024, num_blocks=8, heads=16, mlp_dim=2048, dropout=0.1)
Init of the PhysicFormer model.
Parameter:
- input_channels (int): Number of input image channels (e.g., 1 for grayscale, 3 for RGB).
- output_channels (int): Number of output image channels.
- img_size (int): Size (height and width) of the input image in pixels.
- patch_size (int): Size of each square patch to split the image into. The image size must be divisible by this value.
- embedded_dim (int): Dimension of the patch embedding (feature space size per token).
- num_blocks (int): Number of Transformer Encoder blocks.
- heads (int): Number of attention heads per Attention layer.
- mlp_dim (int): Hidden dimension of the feed-forward (MLP) layers within each Transformer block.
- dropout (float): Dropout probability for regularization applied after positional encoding and inside MLP.
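A hypothetical configuration sketch (assuming the class is imported): with img_size=256 and patch_size=4 the encoder works on (256 // 4) ** 2 = 4096 tokens per image, and embedded_dim must stay divisible by heads.

model = PhysicFormer(input_channels=1, output_channels=1,
                     img_size=256, patch_size=4,
                     embedded_dim=1024, num_blocks=8,
                     heads=16, mlp_dim=2048, dropout=0.1)

print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")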
def get_input_channels(self)
Returns the number of input channels used by the model.
Returns:
- int: Number of input channels expected by the model.
def get_output_channels(self)
Returns the number of output channels produced by the model.
Returns:
- int: Number of output channels the model generates.
def forward(self, x)
Forward pass of the PhysicFormer network.
Parameter:
- x (torch.tensor): Input image tensor of shape (batch_size, input_channels, height, width).
Returns:
- torch.tensor: Refined output image tensor of shape (batch_size, output_channels, height, width), with values normalized to [0.0, 1.0].
Notes:
- The output passes through a `sigmoid()` activation, ensuring all pixel values ∈ [0, 1].
- Designed for physics-informed or visual reconstruction tasks where local and global consistency are important.
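An end-to-end usage sketch (my example; the import path follows the module name above, and the weights are untrained):

import torch
from image_to_image.models.transformer import PhysicFormer

model = PhysicFormer(img_size=256, patch_size=4, input_channels=1, output_channels=1)
model.eval()

with torch.no_grad():
    x = torch.rand(2, 1, 256, 256)         # dummy batch in [0, 1]
    y = model(x)

print(y.shape)                             # torch.Size([2, 1, 256, 256])
print(float(y.min()), float(y.max()))      # both within [0.0, 1.0] because of the final sigmoid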