Source code for torch_frame.nn.encoding.positional_encoding

import torch
from torch import Tensor
from torch.nn import Module


class PositionalEncoding(Module):
    r"""Positional encoding introduced in `"Attention Is All You Need"
    <https://arxiv.org/abs/1706.03762>`_ paper. Given an input tensor of shape
    :obj:`(*, )`, this encoding expands it into an output tensor of shape
    :obj:`(*, out_size)`.

    With :math:`d` = :obj:`out_size` and frequencies
    :math:`w_i = 10000^{-2i/d}` for :math:`i = 0, \dots, d/2 - 1`, the first
    :math:`d/2` output channels hold :math:`\sin(x \cdot w_i)` and the last
    :math:`d/2` channels hold :math:`\cos(x \cdot w_i)`.

    Args:
        out_size (int): The output dimension size. Must be a non-negative
            even integer.

    Raises:
        ValueError: If :obj:`out_size` is odd or negative.
    """
    def __init__(self, out_size: int) -> None:
        super().__init__()
        if out_size % 2 != 0:
            raise ValueError(
                f"out_size should be divisible by 2 (got {out_size}).")
        if out_size < 0:
            # Without this check, torch.arange below fails with a much less
            # informative RuntimeError for negative even sizes.
            raise ValueError(
                f"out_size should be non-negative (got {out_size}).")
        self.out_size = out_size
        # Frequencies w_i = (1 / 10000) ** (2i / out_size), a 1-D tensor of
        # length out_size // 2. Registered as a buffer so it moves with the
        # module (.to / .cuda) and is stored in the state dict.
        self.mult_term: Tensor
        self.register_buffer(
            "mult_term",
            torch.pow(
                1 / 10000.0,
                torch.arange(0, self.out_size, 2) / out_size,
            ),
        )

    def forward(self, input_tensor: Tensor) -> Tensor:
        r"""Encode :obj:`input_tensor` of shape :obj:`(*, )` into a tensor of
        shape :obj:`(*, out_size)`.

        Args:
            input_tensor (Tensor): Non-negative values to encode.

        Returns:
            Tensor: :obj:`cat([sin(x * w), cos(x * w)], dim=-1)`, where
            :obj:`w` is the :obj:`mult_term` buffer.

        Raises:
            ValueError: If any entry of :obj:`input_tensor` is negative (NaN
                entries also fail the :obj:`>= 0` comparison).
        """
        # Explicit check rather than `assert`, which is stripped under
        # `python -O` and would then silently accept invalid input.
        if not bool(torch.all(input_tensor >= 0)):
            raise ValueError("input_tensor should be non-negative.")
        # (*, 1) * (out_size // 2, ) broadcasts to (*, out_size // 2).
        mult_tensor = input_tensor.unsqueeze(-1) * self.mult_term
        # cat([(*, out_size // 2), (*, out_size // 2)]) -> (*, out_size)
        return torch.cat([torch.sin(mult_tensor), torch.cos(mult_tensor)],
                         dim=-1)