Yanisadel commited on
Commit
95e1897
·
1 Parent(s): 6acf14e

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +13 -4
chatNT.py CHANGED
@@ -590,7 +590,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
590
  def __init__(self, config: ChatNTConfig) -> None:
591
  print("(debug) Entering in class")
592
  if isinstance(config, dict):
593
- print("(debug) going in if condition")
594
  # If config is a dictionary instead of ChatNTConfig (which can happen
595
  # depending how the config was saved), we convert it to the config
596
  config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
@@ -598,14 +597,24 @@ class TorchMultiOmicsModel(PreTrainedModel):
598
  )
599
  config["gpt_config"] = GptConfig(**config["gpt_config"])
600
  config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
601
- print("(debug) Type esm_config : ", type(config["esm_config"]))
602
- print("(debug) esm_config : ", config["esm_config"])
603
  config["perceiver_resampler_config"] = PerceiverResamplerConfig(
604
  **config["perceiver_resampler_config"]
605
  )
606
  config = ChatNTConfig(**config) # type: ignore
607
- print("(debug) Type config : ", type(config))
608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609
  print("(debug) config : ", config)
610
  print("(debug) config type : ", type(config))
611
  print("(debug) gpt config : ", config.gpt_config)
 
590
  def __init__(self, config: ChatNTConfig) -> None:
591
  print("(debug) Entering in class")
592
  if isinstance(config, dict):
 
593
  # If config is a dictionary instead of ChatNTConfig (which can happen
594
  # depending how the config was saved), we convert it to the config
595
  config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
 
597
  )
598
  config["gpt_config"] = GptConfig(**config["gpt_config"])
599
  config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
 
 
600
  config["perceiver_resampler_config"] = PerceiverResamplerConfig(
601
  **config["perceiver_resampler_config"]
602
  )
603
  config = ChatNTConfig(**config) # type: ignore
 
604
 
605
+ else:
606
+ if isinstance(config.gpt_config, dict):
607
+ config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
608
+ **config.gpt_config.["rope_config"]
609
+ )
610
+ config.gpt_config = GptConfig(**config.gpt_config)
611
+
612
+ if isinstance(config.esm_config, dict):
613
+ config.esm_config = ESMTransformerConfig(**config.esm_config)
614
+
615
+ if isinstance(config.perceiver_resampler_config, dict):
616
+ config.esm_config = PerceiverResamplerConfig(**config.perceiver_resampler_config)
617
+
618
  print("(debug) config : ", config)
619
  print("(debug) config type : ", type(config))
620
  print("(debug) gpt config : ", config.gpt_config)