・ Language: Python 3
・ Library: transformers
The code is as follows.
import transformers
import torch
In most cases, the strings you feed to the model will not all be the same length. However, to perform tensor computations on this string data inside the model, the lengths must match. So we decide on a maximum length and, for any input shorter than that, fill the remainder with padding tokens up to that length. (See the next section.)
MAX_LENGTH = 192
Tokenization works as follows.
tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")
text = "This is a pen."
text2 = "I am a man"
#encode with add_special_tokens=False so that build_inputs_with_special_tokens does not add <s>/</s> twice
ids = tokenizer.encode(text, add_special_tokens=False)
ids2 = tokenizer.encode(text2, add_special_tokens=False)
token_ids = tokenizer.build_inputs_with_special_tokens(ids, ids2)
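As a quick sanity check (a minimal sketch, assuming the `roberta-base` tokenizer from the text), you can decode the combined IDs to see exactly which special tokens were inserted around the two segments:

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("roberta-base")

# Encode without special tokens, then let build_inputs_with_special_tokens add them
ids = tokenizer.encode("This is a pen.", add_special_tokens=False)
ids2 = tokenizer.encode("I am a man", add_special_tokens=False)
token_ids = tokenizer.build_inputs_with_special_tokens(ids, ids2)

# RoBERTa wraps a sentence pair as <s> ... </s></s> ... </s>
print(tokenizer.decode(token_ids))
```

Decoding back to text like this is a convenient way to confirm the special-token layout before moving on to padding.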
Padding works as explained in the previous section. The attention mask tells the model which tokens are real and which are padding: it is "1" for real tokens and "0" for padding tokens.
#Attention Mask
mask = [1] * len(token_ids)
#Padding
padding_length = MAX_LENGTH - len(token_ids)
if padding_length > 0:
token_ids = token_ids + ([tokenizer.pad_token_id] * padding_length)  #roberta-base's pad token id is 1
mask = mask + ([0] * padding_length)
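The padding step above can be verified with a short self-contained sketch (the ID list below is illustrative, not produced by the tokenizer): after padding, the IDs and the mask should both be exactly `MAX_LENGTH` long, and the mask should contain one 1 per real token.

```python
MAX_LENGTH = 192
token_ids = [0, 713, 16, 10, 7670, 4, 2]  # example IDs, for illustration only
mask = [1] * len(token_ids)

pad_id = 1  # roberta-base's pad token id
padding_length = MAX_LENGTH - len(token_ids)
if padding_length > 0:
    token_ids = token_ids + [pad_id] * padding_length
    mask = mask + [0] * padding_length

print(len(token_ids), len(mask), sum(mask))
# → 192 192 7
```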
You can load a different model by passing a model name other than "roberta-base". Other model names are listed on the Hugging Face model hub.
model = transformers.AutoModel.from_pretrained("roberta-base")
At this point the input string has been converted to IDs. Since they are Python lists, convert them to torch.Tensor. When you feed them to the model, it returns (1) the output of the final encoder layer and (2) the pooled output, which is (1) processed by the pooler layer. The size of each is shown in the output of the code below.
#Convert the IDs and mask to a type the model accepts (list -> torch.Tensor)
token_ids_tensor = torch.tensor([token_ids], dtype=torch.long)
mask_tensor = torch.tensor([mask], dtype=torch.long)
#conversion
out = model(input_ids=token_ids_tensor, attention_mask=mask_tensor)
print(out[0].shape)
#output
#torch.Size([1, 192, 768])
print(out[1].shape)
#output
#torch.Size([1, 768])
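A common next step is to reduce `out[0]` (shape `[1, 192, 768]`) to a single fixed-size sentence vector by taking the hidden state of the first token (`<s>`). The sketch below simulates the model output with a random tensor so it runs without downloading the model; with the real model you would use `out[0]` directly.

```python
import torch

# Stand-in for out[0]: [batch, MAX_LENGTH, hidden] = [1, 192, 768]
last_hidden = torch.randn(1, 192, 768)

# Hidden state of the first token (<s>), often used as a sentence vector
sentence_vec = last_hidden[:, 0, :]
print(sentence_vec.shape)
# → torch.Size([1, 768])
```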