| import torch.cuda |
| |
| try: |
| from torch._C import _cudnn |
| except ImportError: |
| # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(), |
| # so it's safe to not emit any checks here. |
| _cudnn = None # type: ignore[assignment] |
| |
| |
| def get_cudnn_mode(mode): |
| if mode == 'RNN_RELU': |
| return int(_cudnn.RNNMode.rnn_relu) |
| elif mode == 'RNN_TANH': |
| return int(_cudnn.RNNMode.rnn_tanh) |
| elif mode == 'LSTM': |
| return int(_cudnn.RNNMode.lstm) |
| elif mode == 'GRU': |
| return int(_cudnn.RNNMode.gru) |
| else: |
| raise Exception("Unknown mode: {}".format(mode)) |
| |
| |
| # NB: We don't actually need this class anymore (in fact, we could serialize the |
| # dropout state for even better reproducibility), but it is kept for backwards |
| # compatibility for old models. |
| class Unserializable: |
| |
| def __init__(self, inner): |
| self.inner = inner |
| |
| def get(self): |
| return self.inner |
| |
| def __getstate__(self): |
| # Note: can't return {}, because python2 won't call __setstate__ |
| # if the value evaluates to False |
| return "<unserializable>" |
| |
| def __setstate__(self, state): |
| self.inner = None |
| |
| |
| def init_dropout_state(dropout, train, dropout_seed, dropout_state): |
| dropout_desc_name = 'desc_' + str(torch.cuda.current_device()) |
| dropout_p = dropout if train else 0 |
| if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None): |
| if dropout_p == 0: |
| dropout_state[dropout_desc_name] = Unserializable(None) |
| else: |
| dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state( # type: ignore[call-arg] |
| dropout_p, |
| train, |
| dropout_seed, |
| self_ty=torch.uint8, |
| device=torch.device('cuda'))) |
| dropout_ts = dropout_state[dropout_desc_name].get() |
| return dropout_ts |