xlmr_m18.py 7.0 KB

"""
XLM-RoBERTa based text encoders that pass the transformer's hidden states through a
linear projection, producing per-token prompt embeddings ('projection_state') for
downstream conditioning.
"""
from typing import Optional

import torch
import torch.nn as nn
from transformers import BertConfig, BertPreTrainedModel, XLMRobertaModel, XLMRobertaTokenizer
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig

from modules import torch_utils


class BertSeriesConfig(BertConfig):
    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, project_dim=512, pooler_fn="average", learn_encoder=False, model_type='bert', **kwargs):
        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class RobertaSeriesConfig(XLMRobertaConfig):
    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, project_dim=512, pooler_fn='cls', learn_encoder=False, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.project_dim = project_dim
        self.pooler_fn = pooler_fn
        self.learn_encoder = learn_encoder


class BertSeriesModelWithTransformation(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    config_class = BertSeriesConfig

    def __init__(self, config=None, **kwargs):
        # modify initialization for autoloading
        if config is None:
            # default config matching xlm-roberta-large, plus the projection settings
            config = XLMRobertaConfig()
            config.attention_probs_dropout_prob = 0.1
            config.bos_token_id = 0
            config.eos_token_id = 2
            config.hidden_act = 'gelu'
            config.hidden_dropout_prob = 0.1
            config.hidden_size = 1024
            config.initializer_range = 0.02
            config.intermediate_size = 4096
            config.layer_norm_eps = 1e-05
            config.max_position_embeddings = 514
            config.num_attention_heads = 16
            config.num_hidden_layers = 24
            config.output_past = True
            config.pad_token_id = 1
            config.position_embedding_type = "absolute"
            config.type_vocab_size = 1
            config.use_cache = True
            config.vocab_size = 250002
            config.project_dim = 1024
            config.learn_encoder = False
        super().__init__(config)
        self.roberta = XLMRobertaModel(config)
        self.transformation = nn.Linear(config.hidden_size, config.project_dim)
        # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
        # self.pooler = lambda x: x[:,0]
        # self.post_init()

        # project the layer-normalized penultimate hidden layer instead of the final one (see forward())
        self.has_pre_transformation = True
        if self.has_pre_transformation:
            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_init()

    def encode(self, c):
        device = torch_utils.get_param(self).device
        text = self.tokenizer(
            c,
            truncation=True,
            max_length=77,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # the tokenizer already returns tensors, so only move them to the model's device
        text["input_ids"] = text["input_ids"].to(device)
        text["attention_mask"] = text["attention_mask"].to(device)
        features = self(**text)
        return features['projection_state']

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        r"""
        Runs the XLM-RoBERTa encoder and returns a dict containing the projected
        hidden states under ``projection_state``, alongside ``last_hidden_state``,
        ``hidden_states`` and ``attentions``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # # last module outputs
        # sequence_output = outputs[0]
        # # project every module
        # sequence_output_ln = self.pre_LN(sequence_output)
        # # pooler
        # pooler_output = self.pooler(sequence_output_ln)
        # pooler_output = self.transformation(pooler_output)
        # projection_state = self.transformation(outputs.last_hidden_state)

        if self.has_pre_transformation:
            # project the layer-normalized penultimate hidden layer
            sequence_output2 = outputs["hidden_states"][-2]
            sequence_output2 = self.pre_LN(sequence_output2)
            projection_state2 = self.transformation_pre(sequence_output2)
            return {
                "projection_state": projection_state2,
                "last_hidden_state": outputs.last_hidden_state,
                "hidden_states": outputs.hidden_states,
                "attentions": outputs.attentions,
            }
        else:
            # project the final hidden layer directly
            projection_state = self.transformation(outputs.last_hidden_state)
            return {
                "projection_state": projection_state,
                "last_hidden_state": outputs.last_hidden_state,
                "hidden_states": outputs.hidden_states,
                "attentions": outputs.attentions,
            }

        # return {
        #     'pooler_output':pooler_output,
        #     'last_hidden_state':outputs.last_hidden_state,
        #     'hidden_states':outputs.hidden_states,
        #     'attentions':outputs.attentions,
        #     'projection_state':projection_state,
        #     'sequence_out': sequence_output
        # }


class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
    base_model_prefix = 'roberta'
    config_class = RobertaSeriesConfig
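

if __name__ == "__main__":
    # Usage sketch (illustrative only, not part of the original module): build the
    # encoder and project a prompt into conditioning features. It assumes the
    # 'xlm-roberta-large' tokenizer can be resolved and that the webui 'modules'
    # package is importable; a real pipeline would load fine-tuned weights rather
    # than the randomly initialized ones created here.
    model = BertSeriesModelWithTransformation()
    model.eval()
    with torch.no_grad():
        cond = model.encode(["a photo of an astronaut riding a horse"])
    print(cond.shape)  # (1, 77, config.project_dim): one projected vector per token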