# xlmr.py

from transformers import BertPreTrainedModel, BertConfig
import torch.nn as nn
import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
from transformers import XLMRobertaModel, XLMRobertaTokenizer
from typing import Optional

from modules.torch_utils import get_param


class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, project_dim=512, pooler_fn="average", learn_encoder=False, model_type='bert', **kwargs):
        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, project_dim=512, pooler_fn='cls', learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kwargs):
        # modify initialization for autoloading: without an explicit config, fall back to
        # xlm-roberta-large hyperparameters plus the projection settings used by this model
        if config is None:
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob = 0.1
            config.bos_token_id = 0
            config.eos_token_id = 2
            config.hidden_act = 'gelu'
            config.hidden_dropout_prob = 0.1
            config.hidden_size = 1024
            config.initializer_range = 0.02
            config.intermediate_size = 4096
            config.layer_norm_eps = 1e-05
            config.max_position_embeddings = 514
            config.num_attention_heads = 16
            config.num_hidden_layers = 24
            config.output_past = True
            config.pad_token_id = 1
            config.position_embedding_type = "absolute"
            config.type_vocab_size = 1
            config.use_cache = True
            config.vocab_size = 250002
            config.project_dim = 768
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        self.pooler = lambda x: x[:, 0]  # first-token (CLS) pooling
        self.post_init()

    def encode(self, c):
        device = get_param(self).device
        text = self.tokenizer(
            c,
            truncation=True,
            max_length=77,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # return_tensors="pt" already yields tensors; just move them to the model's device
        text["input_ids"] = text["input_ids"].to(device)
        text["attention_mask"] = text["attention_mask"].to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # last encoder layer output
        sequence_output = outputs[0]

        # layer-normalize, pool the first token, then project the pooled vector
        sequence_output_ln = self.pre_LN(sequence_output)
        pooler_output = self.pooler(sequence_output_ln)
        pooler_output = self.transformation(pooler_output)

        # project every token of the last hidden state
        projection_state = self.transformation(outputs.last_hidden_state)

        return {
            'pooler_output': pooler_output,
            'last_hidden_state': outputs.last_hidden_state,
            'hidden_states': outputs.hidden_states,
            'attentions': outputs.attentions,
            'projection_state': projection_state,
            'sequence_out': sequence_output,
        }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class = RobertaSeriesConfig
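

if __name__ == "__main__":
    # Minimal usage sketch: load a compatible checkpoint and project a prompt.
    # "path/to/text-encoder-checkpoint" is a hypothetical placeholder path; building the
    # model also fetches the 'xlm-roberta-large' tokenizer used in __init__ above.
    model = BertSeriesModelWithTransformation.from_pretrained("path/to/text-encoder-checkpoint")
    model.eval()
    with torch.no_grad():
        projection = model.encode(["a photograph of an astronaut riding a horse"])
    # encode() returns the per-token projection state: (batch, 77, config.project_dim)
    print(projection.shape)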