update float32 to float64
This commit is contained in:
@@ -82,7 +82,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
|
||||
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
|
||||
if mask is not None:
|
||||
fill_value = torch.finfo(torch.float32).min
|
||||
fill_value = torch.finfo(torch.float64).min
|
||||
energy.mask_fill(~mask, fill_value)
|
||||
|
||||
scaling = self.emb_size ** (1 / 2)
|
||||
|
||||
Reference in New Issue
Block a user