| from typing import Tuple |
| |
| import torch |
| from torch import Tensor |
| |
| |
def milstm_cell(
    x: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    alpha: Tensor,
    beta_i: Tensor,
    beta_h: Tensor,
    bias: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Compute one step of a multiplicative-integration LSTM cell.

    The input and recurrent projections are combined with an element-wise
    product term in addition to the usual additive terms, then fed through
    the same gate nonlinearities as a plain LSTM cell.

    Args:
        x: Input matrix; must be 2-D because it is passed to ``mm``
            (presumably ``(batch, input_size)`` -- confirm against callers).
        hx: Previous hidden state, 2-D (passed to ``mm``).
        cx: Previous cell state; multiplied element-wise by the forget gate.
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        alpha: Scale applied to the product of the two projections.
        beta_i: Scale applied to the input projection.
        beta_h: Scale applied to the recurrent projection.
        bias: Additive term on the gate pre-activations.

    Returns:
        The pair ``(hy, cy)``: the new hidden state and the new cell state.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())

    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias

    # Same as LSTMCell after this point: split the pre-activations into four
    # equal chunks along dim 1 (input, forget, cell candidate, output).
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()

    return hy, cy
| |
| |
def lstm_cell(
    input: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Advance an LSTM cell by one step.

    Args:
        input: 2-D input matrix (it is passed to ``mm``).
        hidden: Pair ``(hx, cx)`` holding the previous hidden and cell state.
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added after the two projections are summed.
        b_hh: Second bias, added last.

    Returns:
        The pair ``(hy, cy)``: the new hidden state and the new cell state.
    """
    h_prev, c_prev = hidden

    # Gate pre-activations, accumulated left to right: input projection,
    # recurrent projection, then the two biases.
    input_proj = input.mm(w_ih.t())
    hidden_proj = h_prev.mm(w_hh.t())
    preact = input_proj + hidden_proj + b_ih + b_hh

    # Four equal chunks along dim 1: input, forget, cell candidate, output.
    i_pre, f_pre, g_pre, o_pre = preact.chunk(4, 1)

    i_gate = i_pre.sigmoid()
    f_gate = f_pre.sigmoid()
    g_cand = g_pre.tanh()
    o_gate = o_pre.sigmoid()

    c_next = (f_gate * c_prev) + (i_gate * g_cand)
    h_next = o_gate * c_next.tanh()

    return h_next, c_next
| |
| |
def flat_lstm_cell(
    input: Tensor,
    hx: Tensor,
    cx: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Advance an LSTM cell by one step, taking the state as two arguments.

    Performs the same computation as an LSTM cell that receives its state as
    a ``(hx, cx)`` tuple; here the two state tensors are passed separately.

    Args:
        input: 2-D input matrix (it is passed to ``mm``).
        hx: Previous hidden state, 2-D (passed to ``mm``).
        cx: Previous cell state; multiplied element-wise by the forget gate.
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added after the two projections are summed.
        b_hh: Second bias, added last.

    Returns:
        The pair ``(hy, cy)``: the new hidden state and the new cell state.
    """
    # Gate pre-activations, accumulated in the order: input projection,
    # recurrent projection, b_ih, b_hh.
    from_input = input.mm(w_ih.t())
    from_hidden = hx.mm(w_hh.t())
    preact = from_input + from_hidden + b_ih + b_hh

    # Four equal chunks along dim 1: input, forget, cell candidate, output.
    pieces = preact.chunk(4, 1)
    write_gate = pieces[0].sigmoid()
    keep_gate = pieces[1].sigmoid()
    candidate = pieces[2].tanh()
    read_gate = pieces[3].sigmoid()

    cy = (keep_gate * cx) + (write_gate * candidate)
    hy = read_gate * cy.tanh()

    return hy, cy
| |
| |
def premul_lstm_cell(
    igates: Tensor,
    hidden: Tuple[Tensor, Tensor],
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tuple[Tensor, Tensor]:
    """Advance an LSTM cell by one step from a precomputed input projection.

    Unlike a cell that multiplies the input by an input weight matrix itself,
    this variant receives that contribution ready-made as ``igates`` and only
    performs the recurrent matrix multiply.

    Args:
        igates: Input-side contribution to the gate pre-activations
            (presumably ``input @ w_ih.t()`` computed by the caller --
            confirm against call sites).
        hidden: Pair ``(hx, cx)`` holding the previous hidden and cell state.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added after the recurrent projection.
        b_hh: Second bias, added last.

    Returns:
        The pair ``(hy, cy)``: the new hidden state and the new cell state.
    """
    h_prev, c_prev = hidden

    # Accumulate in the order: igates, recurrent projection, b_ih, b_hh.
    recurrent = h_prev.mm(w_hh.t())
    preact = igates + recurrent + b_ih + b_hh

    # Four equal chunks along dim 1: input, forget, cell candidate, output.
    i_pre, f_pre, g_pre, o_pre = preact.chunk(4, 1)

    i_gate = i_pre.sigmoid()
    f_gate = f_pre.sigmoid()
    g_cand = g_pre.tanh()
    o_gate = o_pre.sigmoid()

    c_next = (f_gate * c_prev) + (i_gate * g_cand)
    h_next = o_gate * c_next.tanh()

    return h_next, c_next
| |
| |
def premul_lstm_cell_no_bias(
    igates: Tensor, hidden: Tuple[Tensor, Tensor], w_hh: Tensor, b_hh: Tensor
) -> Tuple[Tensor, Tensor]:
    """Advance an LSTM cell by one step from a precomputed input projection.

    Only ``b_hh`` is added here; there is no separate input-side bias
    argument (presumably the caller has already folded it into ``igates`` --
    confirm against call sites).

    Args:
        igates: Input-side contribution to the gate pre-activations.
        hidden: Pair ``(hx, cx)`` holding the previous hidden and cell state.
        w_hh: Recurrent weight matrix; used transposed.
        b_hh: Bias added after the recurrent projection.

    Returns:
        The pair ``(hy, cy)``: the new hidden state and the new cell state.
    """
    h_prev, c_prev = hidden

    # Accumulate in the order: igates, recurrent projection, b_hh.
    recurrent = h_prev.mm(w_hh.t())
    preact = igates + recurrent + b_hh

    # Four equal chunks along dim 1: input, forget, cell candidate, output.
    pieces = preact.chunk(4, 1)
    write_gate = pieces[0].sigmoid()
    keep_gate = pieces[1].sigmoid()
    candidate = pieces[2].tanh()
    read_gate = pieces[3].sigmoid()

    c_next = (keep_gate * c_prev) + (write_gate * candidate)
    h_next = read_gate * c_next.tanh()

    return h_next, c_next
| |
| |
def gru_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one step of a GRU cell.

    Args:
        input: 2-D input matrix (it is passed to ``mm``).
        hidden: Previous hidden state, 2-D (passed to ``mm``).
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        The new hidden state.
    """
    gi = torch.mm(input, w_ih.t()) + b_ih
    gh = torch.mm(hidden, w_hh.t()) + b_hh

    # Each projection is split into three equal chunks along dim 1.
    i_r, i_i, i_n = gi.chunk(3, 1)
    h_r, h_i, h_n = gh.chunk(3, 1)

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    # The reset gate scales only the recurrent part of the candidate.
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Algebraically (1 - inputgate) * newgate + inputgate * hidden, so
    # `inputgate` is the weight given to the previous hidden state.
    hy = newgate + inputgate * (hidden - newgate)

    return hy
| |
| |
def rnn_relu_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one step of a simple recurrent cell with a ReLU nonlinearity.

    Args:
        input: 2-D input matrix (it is passed to ``mm``).
        hidden: Previous hidden state, 2-D (passed to ``mm``).
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        ``relu`` of the sum of the two biased projections.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.relu(igates + hgates)
| |
| |
def rnn_tanh_cell(
    input: Tensor,
    hidden: Tensor,
    w_ih: Tensor,
    w_hh: Tensor,
    b_ih: Tensor,
    b_hh: Tensor,
) -> Tensor:
    """Compute one step of a simple recurrent cell with a tanh nonlinearity.

    Args:
        input: 2-D input matrix (it is passed to ``mm``).
        hidden: Previous hidden state, 2-D (passed to ``mm``).
        w_ih: Input weight matrix; used transposed.
        w_hh: Recurrent weight matrix; used transposed.
        b_ih: Bias added to the input projection.
        b_hh: Bias added to the recurrent projection.

    Returns:
        ``tanh`` of the sum of the two biased projections.
    """
    igates = torch.mm(input, w_ih.t()) + b_ih
    hgates = torch.mm(hidden, w_hh.t()) + b_hh
    return torch.tanh(igates + hgates)