Query related to data iterators for Seq2Seq translation using bert-gpt2 #256

@seekingpeace

Description

Hi, while trying to use the following snippet, I ran into a few questions:

import torch
import torch.nn as nn

import texar.torch as tx
from texar.torch.run import *

# (1) Modeling
class BERTGPT2Model(nn.Module):
  """An encoder-decoder model with BERT as the encoder and GPT-2 as the decoder."""
  def __init__(self, vocab_size):
    super().__init__()
    # Use hyperparameter dict for model configuration
    self.tokeniserBERT = tx.data.BERTTokenizer('bert-base-uncased')
    self.tokeniserGPT2 = tx.data.GPT2Tokenizer('gpt2-medium')
    self.encoder = tx.modules.BERTEncoder('bert-base-uncased')  # With pre-trained weights
    self.decoder = tx.modules.GPT2Decoder("gpt2-medium")  # With pre-trained weights

  def _get_decoder_output(self, batch, train=True):
    """Perform model inference, i.e., decoding."""
    # BERTEncoder embeds the token ids internally, so they are passed in directly;
    # it returns (sequence_output, pooled_output).
    enc_states, _ = self.encoder(inputs=batch['source_text_ids'],
                                 sequence_length=batch['source_length'])
    if train:  # Teacher-forcing decoding at training time
      return self.decoder(
          inputs=batch['target_text_ids'], sequence_length=batch['target_length'] - 1,
          memory=enc_states, memory_sequence_length=batch['source_length'])
    else:      # Beam search decoding at prediction time
      start_tokens = torch.full_like(batch['source_text_ids'][:, 0], BOS)  # which BOS to use?
      return self.decoder(
          beam_width=5, start_tokens=start_tokens,
          memory=enc_states, memory_sequence_length=batch['source_length'])

  def forward(self, batch):
    """Compute training loss."""
    outputs = self._get_decoder_output(batch)
    loss = tx.losses.sequence_sparse_softmax_cross_entropy(  # Sequence loss
        labels=batch['target_text_ids'][:, 1:], logits=outputs.logits,
        sequence_length=batch['target_length'] - 1)  # Automatic masking
    return {"loss": loss}

  def predict(self, batch):
    """Compute model predictions."""
    sequence, _ = self._get_decoder_output(batch, train=False)
    return {"gen_text_ids": sequence}

  
# (2) Data
# Create dataset splits using built-in data loaders
# (data_hparams is sketched right after this snippet)
datasets = {split: tx.data.PairedTextData(hparams=data_hparams[split])
            for split in ["train", "valid", "test"]}

model = BERTGPT2Model(datasets["train"].target_vocab.size)

# (3) Training
# Manage the train-eval loop with the Executor API
executor = Executor(
  model=model, datasets=datasets,
  optimizer={"type": torch.optim.Adam, "kwargs": {"lr": 5e-4}},
  stop_training_on=cond.epoch(20),
  log_every=cond.iteration(100),
  validate_every=cond.epoch(1),
  train_metric=("loss", metric.RunningAverage(10, pred_name="loss")),
  valid_metric=metric.BLEU(pred_name="gen_text_ids", label_name="target_text_ids"),
  save_every=cond.validation(better=True),
  checkpoint_dir="outputs/saved_models/")
executor.train()
executor.test(datasets["test"]) 
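
For reference, data_hparams is not defined in the snippet above. I am assuming a file-based config roughly like the following; the key names (source_dataset, target_dataset, files, vocab_file, batch_size) follow the PairedTextData default hparams, and all paths are placeholders. Note that vocab_file here is the plain-text vocab file PairedTextData expects, which is exactly the part I want to replace with the pretrained tokenizers:

# Assumed file-based hparams for PairedTextData (paths are placeholders).
data_hparams = {
  split: {
    "source_dataset": {"files": f"data/{split}.src", "vocab_file": "data/vocab.src"},
    "target_dataset": {"files": f"data/{split}.tgt", "vocab_file": "data/vocab.tgt"},
    "batch_size": 32,
  }
  for split in ["train", "valid", "test"]
}

With the Executor, iteration over the files is handled internally; for manual iteration, I assume tx.data.DataIterator can wrap a dataset directly:

iterator = tx.data.DataIterator(datasets["train"])
for batch in iterator:  # batches carry source/target text ids and lengths
  ...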

In this example:

  1. How should I use data iterators that read from files?
  2. What data config do I need so that source text is processed with tokeniserBERT.encode_text(src) and target text with tokeniserGPT2.encode_text(tgt) before being passed through in the batch?
  3. Does PairedTextData have an option to pass different processors for the two sides in the above use case? (See the sketch after this list.)
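
To make questions 2 and 3 concrete, here is roughly what I imagine if PairedTextData cannot do this. DatasetBase, SequenceDataSource, and Batch (and the process/collate override points) are from the texar.torch.data API, but the tokenizer calls and padding are my assumptions; in particular, the exact return values of encode_text would need checking against the tokenizer docs:

import torch
import texar.torch as tx

class TokenizedPairedData(tx.data.DatasetBase):
  """Sketch: pairs source/target lines, each side tokenized by its own tokenizer."""
  def __init__(self, src_file, tgt_file, tokeniser_src, tokeniser_tgt,
               hparams=None, device=None):
    with open(src_file) as f:
      src_lines = [line.strip() for line in f]
    with open(tgt_file) as f:
      tgt_lines = [line.strip() for line in f]
    # SequenceDataSource yields one (source_line, target_line) pair per example.
    source = tx.data.SequenceDataSource(list(zip(src_lines, tgt_lines)))
    self._tokeniser_src = tokeniser_src
    self._tokeniser_tgt = tokeniser_tgt
    super().__init__(source, hparams=hparams, device=device)

  def process(self, raw_example):
    src, tgt = raw_example
    # Assumed return values: BERTTokenizer.encode_text -> (ids, segment_ids, mask),
    # GPT2Tokenizer.encode_text -> (ids, length). These need checking.
    src_ids, _, _ = self._tokeniser_src.encode_text(src)
    tgt_ids, _ = self._tokeniser_tgt.encode_text(tgt)
    return {"source_text_ids": src_ids, "source_length": len(src_ids),
            "target_text_ids": tgt_ids, "target_length": len(tgt_ids)}

  def collate(self, examples):
    # Pad each side to the longest sequence in the batch (pad id 0 assumed).
    def pad(key):
      max_len = max(len(ex[key]) for ex in examples)
      return torch.tensor([ex[key] + [0] * (max_len - len(ex[key]))
                           for ex in examples])
    return tx.data.Batch(
        len(examples),
        source_text_ids=pad("source_text_ids"),
        source_length=torch.tensor([ex["source_length"] for ex in examples]),
        target_text_ids=pad("target_text_ids"),
        target_length=torch.tensor([ex["target_length"] for ex in examples]))

Is something along these lines the intended approach, or is there a built-in option I am missing?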

TIA
