from typing import Optional
from torch import Tensor
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.nn.modules.loss import BCEWithLogitsLoss
 
 
class GatedBCEWithLogitsLoss(BCEWithLogitsLoss):
    """Gated Binary Cross Entropy with Logits Loss for ignoring specific indices."""
 
    def __init__(
        self,
        weight: Optional[Tensor] = None,
        size_average: Optional[bool] = None,
        reduce: Optional[bool] = None,
        reduction: str = "mean",
        pos_weight: Optional[Tensor] = None,
        ignore_index: int = -100,
    ):
        super().__init__(weight, size_average, reduce, reduction, pos_weight)
        # The base class already registers ``weight`` and ``pos_weight`` as buffers,
        # so they follow the module across devices; no re-registration is needed here.
        # Target value that marks elements to exclude from the loss.
        self.ignore_index = ignore_index
 
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Boolean mask of elements to keep; entries equal to ``ignore_index`` are dropped.
        mask = target != self.ignore_index
        # Compute the per-element loss so the mask can be applied before reducing.
        results_per_sample = binary_cross_entropy_with_logits(
            input, target, self.weight, pos_weight=self.pos_weight, reduction="none"
        )
        if self.reduction == "mean":
            return results_per_sample[mask].mean()
        if self.reduction == "sum":
            return results_per_sample[mask].sum()
        if self.reduction == "none":
            # Preserve the input shape; ignored positions contribute zero loss.
            return results_per_sample * mask
        raise ValueError(f"Unknown reduction: {self.reduction}")
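

# Minimal usage sketch (illustrative only, not part of the original module): the
# tensors below are made-up values showing that entries whose target equals
# ignore_index (-100) do not contribute to the reduced loss.
if __name__ == "__main__":
    import torch

    loss_fn = GatedBCEWithLogitsLoss(ignore_index=-100)
    logits = torch.tensor([[0.2, -1.5, 3.0], [0.7, 0.1, -0.4]])
    targets = torch.tensor([[1.0, 0.0, -100.0], [0.0, -100.0, 1.0]])

    # The mean is taken over the four non-ignored elements only.
    print(loss_fn(logits, targets))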