Repo for the search and displace ingest module, which takes ODF, DOCX and PDF documents and transforms them into .md files to be used with search and displace operations.
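
A minimal sketch of how ingest to Markdown could look, assuming pandoc (via pypandoc) for ODF/DOCX and pdfminer.six for PDF text extraction; the module's actual pipeline and dependencies may differ:

import os
import sys

import pypandoc                                # assumed: pandoc wrapper for odt/docx -> md
from pdfminer.high_level import extract_text   # assumed: plain-text extraction for PDFs

def ingest_to_markdown(path: str) -> str:
    """Convert an .odt, .docx or .pdf file to Markdown text (illustrative only)."""
    ext = os.path.splitext(path)[1].lower()
    if ext in ('.odt', '.docx'):
        # pandoc reads odt/docx natively and can emit GitHub-flavoured Markdown
        return pypandoc.convert_file(path, 'gfm')
    if ext == '.pdf':
        # pandoc cannot read PDF, so fall back to plain text extraction
        return extract_text(path)
    raise ValueError(f'Unsupported file type: {ext}')

if __name__ == '__main__':
    print(ingest_to_markdown(sys.argv[1]))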

import sys

from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer

seq_A = sys.argv[1]
seq_B = sys.argv[2]

# load a pretrained model and a pretrained tokenizer
model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# encode the two sequences; in particular, make clear that they must be
# encoded as "one" input to the model by passing seq_B as the text_pair
encoded = tokenizer.encode_plus(seq_A, text_pair=seq_B, return_tensors='pt')
# print(encoded)
# {'input_ids': tensor([[ 101, 146, 1176, 18621, 106, 102, 2091, 1128, 1176, 1172, 136, 102]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]),
#  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
# NOTE how the token_type_ids are 0 for all tokens in seq_A and 1 for those in seq_B;
# this is how the model knows which token belongs to which sequence

# the model's output is a tuple; we only need the output tensor containing
# the relationship logits, which is the first item in the tuple
seq_relationship_logits = model(**encoded)[0]

# we still need softmax to convert the logits into probabilities
# index 0: sequence B is a continuation of sequence A
# index 1: sequence B is a random sequence
probs = softmax(seq_relationship_logits, dim=1)
print(probs)
# tensor([[9.9993e-01, 6.7607e-05]], grad_fn=<SoftmaxBackward>)
# very high value for index 0: high probability of seq_B being a continuation of seq_A,
# which is what we expect!
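
The script takes the two sentences as command-line arguments; assuming the file is saved as nsp.py (the actual filename is not given here), an invocation would look like:

python nsp.py "I like pizza!" "Do you like it too?"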