import torch
from torch.utils.data import Dataset


def collate_sentences_lm(samples):
    """Collate a list of LM samples into a single batch of stacked tensors."""
    if len(samples) == 0:
        return {}

    ids = torch.LongTensor([s["id"] for s in samples])
    src_tokens = torch.stack([s["source"] for s in samples], 0)
    tgt_tokens = torch.stack([s["target"] for s in samples], 0)
    # Every sample has the same fixed length, so the token count and the
    # per-sentence source lengths can be derived from the first sample.
    ntokens = len(samples) * len(samples[0]["target"])
    src_lengths = torch.LongTensor([len(samples[0]["source"])] * len(samples))

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "input": src_tokens,
        "target": tgt_tokens,
    }
    return batch


class BenchmarkLMDataset(Dataset):
| """ |
| Dataset to benchmark a translation like seq2seq task. |
| Args: |
| vocab_size (int, optional): size of the vocabulary (default 10000). |
| max_source_positions (int, optional): max number of tokens in the |
| source sentence (default: 1024). |
| total_samples (int, optional): the total number of rows in the |
| dataset (default: 10000). |
| """ |
| |
| def __init__( |
| self, vocab_size=10000, max_source_positions=1024, total_samples=10000, |
| ): |
| self.vocab_size = vocab_size |
| self.max_source_positions = max_source_positions |
| self.total_samples = total_samples |
| self.sizes = [self.max_source_positions] * self.total_samples |
| |
| def __getitem__(self, index): |
| length = self.sizes[index] |
| source = torch.randint(1, self.vocab_size, (length,)) |
| target = source.clone() |
| return { |
| "id": index, |
| "source": source, |
| "target": target, |
| } |
| |
| def __len__(self): |
| return self.total_samples |
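

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# A small smoke test showing how these pieces are assumed to fit together:
# BenchmarkLMDataset generates fixed-length random token sequences, and
# collate_sentences_lm is passed as the DataLoader's collate_fn to stack them
# into a batch. The parameter values below are arbitrary and chosen only to
# keep the example fast.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = BenchmarkLMDataset(
        vocab_size=100, max_source_positions=16, total_samples=64
    )
    loader = DataLoader(dataset, batch_size=8, collate_fn=collate_sentences_lm)

    batch = next(iter(loader))
    # Expected: input/target of shape (8, 16) and ntokens == 8 * 16.
    print(batch["input"].shape, batch["target"].shape, batch["ntokens"])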