image_to_image.models.transformer

Module to define basic Transformer parts.
Also defines a Transformer model for image-to-image tasks, the PhysicFormer.

Classes:

  • PatchEmbedding
  • PositionalEncoding
  • Attention
  • TransformerEncoderBlock
  • CNNRefinement
  • PhysicFormer

By Tobia Ippolito

"""
Module to define basic Transformer parts.<br>
Also defines a Transformer model for image-to-image tasks, the PhysicFormer.

Classes:
- PatchEmbedding
- PositionalEncoding
- Attention
- TransformerEncoderBlock
- CNNRefinement
- PhysicFormer

By Tobia Ippolito
"""
# ---------------------------
#        > Imports <
# ---------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F



# ---------------------------
#        > Patching <
# ---------------------------
class PatchEmbedding(nn.Module):
    """
    Converts an image into patches ('tokens' in NLP).

    Image size must be: H x W x C

    Patch size must be: P x P
    """
    def __init__(self, img_size, patch_size=16, input_channels=1, embedded_dim=768):
        """
        Init of Patching.

        Each filter produces one value per patch, 
        and there are embedded_dim filters. <br>
        So every patch is projected into an 'embedded_dim'-dimensional vector. <br>
        For example, the values at [0, 0] across all channels of the embedded image 
        together form the embedding vector of the first patch.

        Parameter:
        - img_size (int): 
            Size (width or height) of your image. The image must have the same width and height.
        - patch_size (int, default=16): 
            Size (width or height) of one patch.
        - input_channels (int): 
            Number of input channels to be expected.
        - embedded_dim (int): 
            Output channels / channels of the embedding.
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = input_channels
        self.embedded_dim = embedded_dim

        self.projection = nn.Conv2d(input_channels, embedded_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """
        Forward pass of patching.

        Parameter:
        - x (torch.tensor): 
            Input image(s).

        Returns:
        - tuple(torch.tensor, tuple(int, int)): 
            The embedded image with the patch-grid height and width.
        """
        # Input:                   (batch_size, in_channels, height, width)
        x = self.projection(x)   # (batch_size, embedded_dim, height/patch_size, width/patch_size)
        B, C, H, W = x.shape
        x = x.flatten(2)         # (batch_size, embedded_dim, num_patches)
        x = x.transpose(1, 2)    # (batch_size, num_patches, embedded_dim)
        return x, (H, W)
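
# Usage sketch (illustrative only, not part of the module; shapes assume the defaults above):
#   emb = PatchEmbedding(img_size=256, patch_size=16, input_channels=1, embedded_dim=768)
#   tokens, (h, w) = emb(torch.randn(2, 1, 256, 256))
#   # tokens.shape == (2, 256, 768)   -> 16x16 = 256 patches, each a 768-dim vector
#   # (h, w) == (16, 16)              -> patch-grid resolution
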
# ---------------------------
#   > Positional Encoding <
# ---------------------------
class PositionalEncoding(nn.Module):
    """
    Adds learnable parameters which encode the positional information of the patches. 
    The position of a patch is important, because a picture only makes sense 
    if the order of its sub-pictures (patches) is right.
    """
    def __init__(self, num_patches, embedded_dim=768):
        """
        Init of Positional Encoding.

        Parameter:
        - num_patches (int): 
            Number of patches ('tokens').
        - embedded_dim (int, default=768): 
            Number of embedding channels.
        """
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches, embedded_dim))

    def forward(self, x):
        """
        Forward pass of positional information adding.

        Parameter:
        - x (torch.tensor): 
            Patch embedded image(s).

        Returns:
        - torch.tensor: 
            The embedded image(s) with positional encoding added.
        """
        # use broadcast addition
        return x + self.positional_embedding
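
# Usage sketch (illustrative only): the learned encoding broadcasts over the batch dimension,
# so the output shape equals the input shape:
#   pe = PositionalEncoding(num_patches=256, embedded_dim=768)
#   pe(tokens).shape == tokens.shape   # (2, 256, 768), reusing `tokens` from the sketch above
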
# ---------------------------
#       > Attention <
# ---------------------------
class Attention(nn.Module):
    """
    The attention layer is the basic building block of a Transformer. <br>
    An attention layer computes relations between all patches. <br>
    This is done by calculating the similarity between two learnable projections, Q and K.
    """
    def __init__(self, embedded_dim, num_heads):
        """
        Init of Attention Layer.

        Parameter:
        - embedded_dim (int): 
            Patch embedding channels.
        - num_heads (int): 
            Number of parallel attention computations.
        """
        super().__init__()
        assert embedded_dim % num_heads == 0, \
            f"embedded_dim ({embedded_dim}) must be divisible by num_heads ({num_heads})"

        self.num_heads = num_heads
        self.head_dim = embedded_dim // num_heads
        self.scale = self.head_dim**-0.5  # scaling factor (only needed for the manual version below)
        self.qkv = nn.Linear(embedded_dim, embedded_dim*3)
        self.fc = nn.Linear(embedded_dim, embedded_dim)

    def forward(self, x):
        """
        Forward pass of Attention Layer.

        softmax(QK^T / sqrt(head_dim)) V

        Parameter:
        - x (torch.tensor): 
            Patch embedded image(s) with positional encoding added.

        Returns:
        - torch.tensor: 
            The attention output passed through a fully connected layer.
        """
        batch_size, num_patches, embedded_dim = x.shape
        qkv = self.qkv(x)  # (batch_size, num_patches, embedded_dim*3)
        qkv = qkv.reshape(batch_size, num_patches, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # -> (3, batch_size, num_heads, num_patches, head_dim)
        q, k, v = qkv.unbind(dim=0)

        # compute scaled dot-product
        # attention_scores = (q@k.transpose(-2, -1)) * self.scale  # (batch_size, num_heads, num_patches, num_patches)
        # attention_weights = attention_scores.softmax(dim=-1)
        # attention_output = (attention_weights@v).transpose(1, 2).reshape(batch_size, num_patches, embedded_dim)

        # -> Memory-efficient attention version (uses flash attention when available)
        attention_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        attention_output = attention_output.transpose(1, 2).reshape(
            batch_size, num_patches, embedded_dim
        )

        return self.fc(attention_output)
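
# Shape sketch (illustrative only, assuming embedded_dim=768 and num_heads=12, i.e. head_dim=64):
#   x:   (B, N, 768)
#   qkv: (B, N, 2304) -> (B, N, 3, 12, 64) -> (3, B, 12, N, 64) -> q, k, v each (B, 12, N, 64)
#   F.scaled_dot_product_attention(q, k, v): (B, 12, N, 64)  (scaling by 1/sqrt(64) is applied internally)
#   transpose + reshape: (B, N, 768) -> fc -> (B, N, 768)
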
# ---------------------------
# > Transformer Encoder Block <
# ---------------------------
class TransformerEncoderBlock(nn.Module):
    """
    A Transformer Encoder Block consists of a self-attention layer 
    followed by a feed-forward layer (MLP = multilayer perceptron), 
    with layer normalization 
    and residual connections / skip connections.
    """
    def __init__(self, embedded_dim, num_heads, mlp_dim, dropout=0.1):
        """
        Init of a Transformer Block.

        Parameter:
        - embedded_dim (int): 
            Patch embedding channels.
        - num_heads (int): 
            Number of parallel attention computations.
        - mlp_dim (int): 
            Hidden/feature dimension of the multilayer perceptron.
        - dropout (float): 
            Probability of dropout regularization.
        """
        super().__init__()
        self.norm_1 = nn.LayerNorm(embedded_dim)
        self.attention = Attention(embedded_dim=embedded_dim, num_heads=num_heads)

        self.norm_2 = nn.LayerNorm(embedded_dim)
        # hidden_dim = int(embedded_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embedded_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embedded_dim),
            nn.Dropout(dropout)
        )


    def forward(self, x):
        """
        Forward pass of Transformer Block.

        Parameter:
        - x (torch.tensor): 
            Patch embedded image(s) with positional encoding added.

        Returns:
        - torch.tensor: 
            The attention output passed through the fully connected layer and the MLP, with layer normalization.
        """
        # self attention with skip connection
        x = self.norm_1(x)
        x = x + self.attention(x)

        # MLP with skip connection
        x = self.norm_2(x)
        x = x + self.mlp(x)

        return x
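
# Note (illustrative): the residual here adds the *normalized* activations
#   x = norm(x); x = x + attention(x)
# which is a slight variation of the usual pre-norm residual x = x + attention(norm(x)).
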
# ---------------------------
#     > CNN Refinement <
# ---------------------------
class CNNRefinement(nn.Module):
    """
    Refinement network to remove transformer artefacts.
    """
    def __init__(self, input_channels=1, hidden_channels=64, output_channels=1, skip_connection=True):
        """
        Init of a CNN Refinement network.

        Parameter:
        - input_channels (int): 
            Number of input channels.
        - hidden_channels (int): 
            Number of hidden/feature channels.
        - output_channels (int): 
            Number of output channels.
        - skip_connection (bool): 
            Whether a skip connection should be used, i.e. whether a second input (the original image) is added to the output. 
            In that case the CNN learns a correction which is applied to the original image.
        """
        super().__init__()

        self.conv_1 = nn.Conv2d(input_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.activation_1 = nn.ReLU()

        self.conv_2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
        self.activation_2 = nn.ReLU()

        self.conv_3 = nn.Conv2d(hidden_channels, output_channels, kernel_size=3, stride=1, padding=1)

        # or just:
        # nn.Sequential(
        #     nn.Conv2d(input_channels, 64, 3, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 64, 3, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, output_channels, 3, padding=1)
        # )

        self.skip_connection = skip_connection


    def forward(self, x, original=None):
        """
        Forward pass of CNN Refinement network.

        Parameter:
        - x (torch.tensor): 
            Reconstructed image(s) from the transformer (pixel space).
        - original (torch.tensor or None, default=None): 
            Original image which will be added at the end if skip_connection is set to True.

        Returns:
        - torch.tensor: 
            The refined image.
        """
        x = self.conv_1(x)
        x = self.activation_1(x)

        x = self.conv_2(x)
        x = self.activation_2(x)

        x = self.conv_3(x)

        if self.skip_connection and original is not None:
            # apply correction to image
            x = x + original

        return x
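
# Usage sketch (illustrative only; `coarse` and `x_input` are hypothetical tensors of shape (B, 1, H, W)):
#   refine = CNNRefinement(input_channels=1, hidden_channels=64, output_channels=1, skip_connection=True)
#   out = refine(coarse, original=x_input)   # out = predicted correction + x_input
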
# ---------------------------
#   > Img2Img Transformer <
# ---------------------------
# The whole model consists of:
#   - patching (tokenizing)
#   - add positional encoding
#   - Transformer Blocks
class PhysicFormer(nn.Module):
    """
    Image-to-Image Transformer.

    The whole model consists of:
    - Patching (tokenizing)
    - Add positional encoding
    - Transformer Blocks (Attention + MLP)
    - Image Reconstruction/Remapping -> Embedded Space to Pixel Space
    - CNN Refinement

    Model logic:<br>
    - `CNN(Transformer(x))` = pure transformation (skip connection = false)
    - `CNN(Transformer(x)) + x_input` = residual refinement (skip connection = true)
    - `Transformer(x) + CNN(Transformer(x)) + x_input` = global field + local correction + geometry/residual (not available yet)
    """
    def __init__(self, input_channels=1, output_channels=1, 
                 img_size=256, patch_size=4, 
                 embedded_dim=1024, num_blocks=8,
                 heads=16, mlp_dim=2048, dropout=0.1):
        """
        Init of the PhysicFormer model.

        Parameter:
        - input_channels (int): 
            Number of input image channels (e.g., 1 for grayscale, 3 for RGB).
        - output_channels (int): 
            Number of output image channels.
        - img_size (int): 
            Size (height and width) of the input image in pixels.
        - patch_size (int): 
            Size of each square patch to split the image into. 
            The image size must be divisible by this value.
        - embedded_dim (int): 
            Dimension of the patch embedding (feature space size per token). 
            Must be divisible by heads.
        - num_blocks (int): 
            Number of Transformer Encoder blocks.
        - heads (int): 
            Number of attention heads per Attention layer.
        - mlp_dim (int): 
            Hidden dimension of the feed-forward (MLP) layers within each Transformer block.
        - dropout (float): 
            Dropout probability for regularization applied after positional encoding and inside the MLP.
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.patch_size = patch_size

        self.patch_embedding = PatchEmbedding(img_size=img_size,
                                              patch_size=patch_size,
                                              input_channels=input_channels,
                                              embedded_dim=embedded_dim)

        num_patches = (img_size//patch_size) * (img_size//patch_size)
        self.positional_encoding = PositionalEncoding(num_patches=num_patches, embedded_dim=embedded_dim)

        self.dropout = nn.Dropout(dropout)

        blocks = []
        for _ in range(num_blocks):
            blocks += [TransformerEncoderBlock(embedded_dim=embedded_dim, num_heads=heads, mlp_dim=mlp_dim, dropout=dropout)]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.to_img = nn.Sequential(
            nn.Linear(embedded_dim, patch_size*patch_size*output_channels)
        )

        self.norm = nn.LayerNorm(embedded_dim)

        self.refinement = CNNRefinement(input_channels=output_channels, hidden_channels=64, output_channels=output_channels, skip_connection=True)


    def get_input_channels(self):
        """
        Returns the number of input channels used by the model.

        Returns:
        - int: 
            Number of input channels expected by the model.
        """
        return self.input_channels

    def get_output_channels(self):
        """
        Returns the number of output channels produced by the model.

        Returns:
        - int: 
            Number of output channels the model generates.
        """
        return self.output_channels


    def forward(self, x):
        """
        Forward pass of the PhysicFormer network.

        Parameter:
        - x (torch.tensor): 
            Input image tensor of shape (batch_size, input_channels, height, width).

        Returns:
        - torch.tensor: 
            Refined output image tensor of shape (batch_size, output_channels, height, width), 
            with values normalized to [0.0, 1.0].

        Notes:
        - The output passes through a `sigmoid()` activation, ensuring all pixel values ∈ [0, 1].
        - Designed for physics-informed or visual reconstruction tasks where local and global consistency are important.
        """
        x_input = x

        # patch embedding / tokenization
        x, (height, width) = self.patch_embedding(x)

        # encoding / add positional information
        x = self.positional_encoding(x)

        x = self.dropout(x)

        # transformer blocks
        for transformer_block in self.transformer_blocks:
            x = transformer_block(x)

        x = self.norm(x)

        # translation to image
        x = self.to_img(x)

        # return it in the right format: [B, C, H, W]
        x = x.transpose(1, 2).reshape(x.shape[0], self.output_channels, self.patch_size*height, self.patch_size*width)
        # .reshape() is used here because .view() right after .transpose() would fail:
        # the transposed tensor is no longer contiguous, so .view() cannot reinterpret its memory layout.

        # refinement
        x = self.refinement(x, original=x_input)

        # Other version:
        # refined = self.refinement(x)

        # # Combine contributions (global + local + input)
        # x = x + refined + x_input

        return torch.sigmoid(x)  # between 0.0 and 1.0 -> alt: torch.clamp(x, 0.0, 1.0)
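
A minimal usage sketch (illustrative only, not part of the module source): the configuration below is a deliberately small, assumed example, chosen so that embedded_dim is divisible by heads and img_size is divisible by patch_size.

model = PhysicFormer(input_channels=1, output_channels=1,
                     img_size=64, patch_size=4,
                     embedded_dim=256, num_blocks=2,
                     heads=8, mlp_dim=512, dropout=0.1)
x = torch.randn(2, 1, 64, 64)   # (batch_size, input_channels, height, width)
y = model(x)                    # (2, 1, 64, 64), values in [0.0, 1.0]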