| |
| |
| |
| |
| |
| import functools |
| from caffe2.python import brew, rnn_cell |
| |
| |
class GRUCell(rnn_cell.RNNCell):
    """GRU cell whose per-timestep graph is assembled from caffe2 ops.

    Unlike LSTMCell, a GRU needs the output of one gate to feed into another
    (reset gate -> output gate). That part of the computation is therefore
    built here, in the graph definition. The final hidden-state update is
    delegated to the ``GRUUnit`` operator (gru_unit_op.{h,cc}).
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,  # Currently unused! Values here will be ignored.
        memory_optimization,
        drop_states=False,
        linear_before_reset=False,
        **kwargs
    ):
        """Store the cell configuration.

        Args:
            input_size: width of the raw input; used as ``dim_in`` of the
                input projection built by ``prepare_input``.
            hidden_size: width of the hidden state and of each gate.
            forget_bias: currently unused (values are ignored). It is still
                coerced with ``float()`` and passed to ``GRUUnit`` as the
                ``forget_bias`` argument.
            memory_optimization: stored on the instance. It is not read
                anywhere in this class.
            drop_states: forwarded to ``GRUUnit`` as ``drop_states``.
            linear_before_reset: if True, the reset gate scales the fc
                projection of the previous hidden state (cudnn semantics).
                Otherwise it scales the previous hidden state before that
                projection.
            **kwargs: forwarded unchanged to ``RNNCell.__init__``.
        """
        super(GRUCell, self).__init__(**kwargs)
        self.input_size = input_size
        self.hidden_size = hidden_size
        # float() raises early if a non-numeric value is supplied.
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization
        self.drop_states = drop_states
        self.linear_before_reset = linear_before_reset

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
        extra_inputs=None,
    ):
        """Append one GRU timestep to ``model.net`` and return ``(hidden_t,)``.

        ``input_t`` is expected to hold the three input projections (reset,
        update, output) concatenated along axis 2, as produced by
        ``prepare_input``. ``extra_inputs`` is accepted for interface
        compatibility but is not used.
        """
        prev_hidden = states[0]

        def hidden_fc(blob_in, name):
            # Hidden-to-hidden projection: hidden_size -> hidden_size on axis 2.
            return brew.fc(
                model,
                blob_in,
                self.scope(name),
                dim_in=self.hidden_size,
                dim_out=self.hidden_size,
                axis=2,
            )

        # Three equal slices of the projected input, one per gate.
        split_names = [
            self.scope(suffix)
            for suffix in ('input_t_reset', 'input_t_update', 'input_t_output')
        ]
        x_reset, x_update, x_output = model.net.Split(
            [input_t],
            split_names,
            axis=2,
        )

        # Hidden-state contributions to the reset and update gates.
        h_reset = hidden_fc(prev_hidden, 'reset_gate_t')
        h_update = hidden_fc(prev_hidden, 'update_gate_t')

        # Reset gate: the input contribution is summed in place into the fc
        # output blob, then squashed with a sigmoid.
        reset_pre = model.net.Sum(
            [h_reset, x_reset],
            self.scope('reset_gate_t'),
        )
        reset_act = model.net.Sigmoid(
            reset_pre,
            self.scope('reset_gate_t_sigmoid'),
        )

        # Output gate hidden contribution. Where the reset gate is applied
        # depends on linear_before_reset (True matches cudnn semantics).
        if not self.linear_before_reset:
            gated_hidden = model.net.Mul(
                [reset_act, prev_hidden],
                self.scope('modified_hidden_t_prev'),
            )
            output_pre = hidden_fc(gated_hidden, 'output_gate_t')
        else:
            output_pre = model.net.Mul(
                [reset_act, hidden_fc(prev_hidden, 'output_gate_t')],
                self.scope('output_gate_t_mul'),
            )

        # Input contributions for the update gate (in place) and the output
        # gate. The reset gate already received its contribution above.
        update_pre = model.net.Sum(
            [h_update, x_update],
            self.scope('update_gate_t'),
        )
        output_pre = model.net.Sum(
            [output_pre, x_output],
            self.scope('output_gate_t_summed'),
        )

        # Pack the three gate pre-activations side by side for GRUUnit.
        gates_t, _ = model.net.Concat(
            [reset_pre, update_pre, output_pre],
            [self.scope('gates_t'), self.scope('_gates_t_concat_dims')],
            axis=2,
        )

        # GRUUnit expects seq_lengths only when sequence_lengths=True.
        step_inputs = [prev_hidden, gates_t]
        if seq_lengths is not None:
            step_inputs.append(seq_lengths)
        step_inputs.append(timestep)

        hidden_t = model.net.GRUUnit(
            step_inputs,
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
            drop_states=self.drop_states,
            sequence_lengths=seq_lengths is not None,
        )
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def prepare_input(self, model, input_blob):
        """Project the raw input to ``3 * hidden_size`` features on axis 2.

        The result holds the reset, update and output gate input
        contributions, which ``_apply`` later splits apart.
        """
        gate_width = 3 * self.hidden_size
        return brew.fc(
            model,
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=gate_width,
            axis=2,
        )

    def get_state_names(self):
        """Return a one-element tuple naming the hidden-state blob."""
        hidden_name = self.scope('hidden_t')
        return (hidden_name,)

    def get_output_dim(self):
        """Return the output width, which equals ``hidden_size``."""
        output_dim = self.hidden_size
        return output_dim
| |
| |
# Convenience constructor: GRU(*args, **kwargs) is equivalent to
# rnn_cell._LSTM(GRUCell, *args, **kwargs), i.e. it reuses _LSTM's network
# builder with GRUCell bound as the first positional argument.
# NOTE(review): assumes _LSTM's first parameter is the cell class -- confirm
# against rnn_cell._LSTM's signature.
GRU = functools.partial(rnn_cell._LSTM, GRUCell)