Trainer
This module contains algorithms to train and test BERT and Linear Layer + CRF models. Author: Lucas Pavanelli
- class trainer.Trainer(model, batch, is_bert=False, criterion=CrossEntropyLoss(), device='cpu')
Trains and tests BERT and Linear Layer + CRF models.
- model : torch.nn.Module
PyTorch model.
- batch : int
Batch size.
- is_bert : bool
Whether the model is a BERT model.
- criterion : torch.nn
PyTorch criterion.
- device : torch.device
PyTorch device.
- test(test_data)
Tests model using test data.
- test_data : list
List of tuples representing test data.
- list
True and predicted values
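The returned list can be fed straight into an evaluation metric. The following is a minimal sketch, assuming the list holds two parallel sequences (true labels first, predicted labels second); that layout, the `accuracy` helper, and the sample labels are assumptions for illustration and are not confirmed by this reference.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        return 0.0
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# Hypothetical stand-in for the value returned by Trainer.test(test_data):
# assumed to be [true_values, predicted_values].
results = [
    ["B-PER", "O", "O", "B-LOC"],      # true labels (made up)
    ["B-PER", "O", "B-LOC", "B-LOC"],  # predicted labels (made up)
]
y_true, y_pred = results
score = accuracy(y_true, y_pred)
print(f"accuracy: {score:.2f}")
```

If the real return value is shaped differently (for example, one list of pairs), only the unpacking line would need to change.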
- train(train_data, optimizer, epoch)
Trains model using train data.
- train_data : list
List of tuples representing train data.
- optimizer : optim.SGD
PyTorch optimizer.
- epoch : int
Number of epoch
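To illustrate how a list of tuples and a batch size fit together, here is a minimal sketch of one common batching scheme: slicing the data into consecutive chunks. Whether `Trainer` batches this way is not stated in this reference. The `iter_batches` helper and the `(tokens, tags)` tuple layout are assumptions made for the example.

```python
def iter_batches(data, batch):
    """Yield consecutive slices of `data`, each holding at most `batch` items."""
    if batch <= 0:
        raise ValueError("batch must be a positive integer")
    for start in range(0, len(data), batch):
        yield data[start:start + batch]


# Hypothetical training data: each tuple is assumed to be (tokens, tags).
train_data = [
    (["Ana", "lives", "here"], ["B-PER", "O", "O"]),
    (["Hello", "world"], ["O", "O"]),
    (["Paris", "is", "big"], ["B-LOC", "O", "O"]),
    (["See", "you"], ["O", "O"]),
    (["Rio"], ["B-LOC"]),
]

# With a batch size of 2, five examples split into chunks of 2, 2 and 1.
batch_sizes = [len(chunk) for chunk in iter_batches(train_data, 2)]
print(batch_sizes)
```

The last chunk is smaller whenever the data length is not a multiple of the batch size, so any per-batch averaging should divide by the actual chunk length.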