import os
import torch
import torch.nn as nn
from transformers import BertModel, GPTNeoForCausalLM, AutoTokenizer
# ⚙️ Ensure temporary directory is writable
os.environ["TMPDIR"] = os.path.expanduser("~/tmp")
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
# 💠 Optional modules (brain & heart, if available)
heart_module = None
brain_module = None
if os.path.isdir("heart"):
try:
from heart import heart
heart_module = heart
except Exception as e:
print(f"[⚠️] Heart module error: {e}")
if os.path.isdir("brain"):
try:
from brain import brain
brain_module = brain
except Exception as e:
print(f"[⚠️] Brain module error: {e}")
# TARSQuantumHybrid Class
class TARSQuantumHybrid(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", gpt_model="EleutherAI/gpt-neo-125M"):
        super(TARSQuantumHybrid, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.gpt = GPTNeoForCausalLM.from_pretrained(gpt_model)
        gpt_hidden_dim = getattr(self.gpt.config, "hidden_size", None) or getattr(self.gpt.config, "n_embd", 768)
        # Project BERT's [CLS] embedding into GPT-Neo's hidden size
        self.embedding_proj = nn.Linear(self.bert.config.hidden_size, gpt_hidden_dim)
        self.tokenizer = AutoTokenizer.from_pretrained(gpt_model)
        # Ensure the tokenizer has a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.gpt.resize_token_embeddings(len(self.tokenizer))
            print("✅ Padding token added and model resized.")
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use BERT's [CLS] embedding as a single conditioning "token" for GPT-Neo
        cls_embedding = bert_output.last_hidden_state[:, 0, :]
        gpt_input = self.embedding_proj(cls_embedding).unsqueeze(1)
        # GPT-Neo is decoder-only and its forward() does not accept decoder_input_ids;
        # if they are provided, append their embeddings after the projected BERT embedding.
        if decoder_input_ids is not None:
            decoder_embeds = self.gpt.get_input_embeddings()(decoder_input_ids)
            gpt_input = torch.cat([gpt_input, decoder_embeds], dim=1)
        outputs = self.gpt(inputs_embeds=gpt_input)
        return outputs
    def chat(self, text, max_length=128):
        # 🧠 Clean and tokenize the input text
        cleaned_text = self.clean_input_text(text)
        if not cleaned_text.strip():
            return "🤖 Please provide a non-empty input."
        encoded_input = self.safe_tokenization(cleaned_text)
        # Extract input_ids and attention_mask
        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]
        # Debug: check the token IDs and vocabulary size
        print(f"Input Text: {cleaned_text}")
        print(f"Input IDs: {input_ids}")
        print(f"Vocabulary Size: {len(self.tokenizer)}")
        # Ensure token IDs are within bounds (len() includes added tokens such as [PAD])
        if input_ids.numel() > 0 and input_ids.max() >= len(self.tokenizer):
            raise ValueError(f"Token ID exceeds model's vocabulary size: {input_ids.max()}")
        decoder_input_ids = torch.tensor([[self.tokenizer.bos_token_id]])
        # 🧪 Generate output using the model
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
            )
        generated_ids = torch.argmax(outputs.logits, dim=-1)
        # Debug: check the generated token IDs
        print(f"Generated Token IDs: {generated_ids}")
        raw_response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # 🧼 Clean model echo by removing the original input from the response
        cleaned = raw_response.replace(cleaned_text, "").strip()
        # 🧠 Add insights from optional modules (brain & heart)
        extra_thoughts = ""
        if brain_module and hasattr(brain_module, "get_brain_insight"):
            extra_thoughts += f"\n🧠 {brain_module.get_brain_insight()}"
        if heart_module and hasattr(heart_module, "get_heart_feeling"):
            extra_thoughts += f"\n❤️ {heart_module.get_heart_feeling()}"
        # 🪄 Return final response
        final_response = cleaned if cleaned else "🤖 ...processing quantum entanglement..."
        return final_response + extra_thoughts
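    # Note: chat() decodes one argmax token per position of a single forward pass, so
    # replies stay very short. A hedged sketch of longer, autoregressive decoding
    # (support for inputs_embeds in generate() depends on the installed transformers
    # version, so this is left as a comment):
    #
    #     gpt_input = self.embedding_proj(cls_embedding).unsqueeze(1)
    #     generated_ids = self.gpt.generate(inputs_embeds=gpt_input, max_new_tokens=max_length)
    #     reply = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)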
    def clean_input_text(self, text):
        # Keep only alphanumeric characters and whitespace
        cleaned_text = ''.join(e for e in text if e.isalnum() or e.isspace())
        return cleaned_text
    def safe_tokenization(self, text):
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        # Clamp token IDs into the embedding range; len() includes added tokens such as [PAD]
        token_ids = [min(i, len(self.tokenizer) - 1) for i in token_ids]
        return {
            "input_ids": torch.tensor(token_ids).unsqueeze(0),  # Add batch dimension
            "attention_mask": torch.ones((1, len(token_ids)), dtype=torch.long)
        }
# ✅ Torch-compatible loader
def load_tars(path="tars_v1.pt"):
    from torch.serialization import add_safe_globals
    # add_safe_globals expects a list of allowed classes (only relevant for weights_only=True loads)
    add_safe_globals([TARSQuantumHybrid])
    model = torch.load(path, weights_only=False)
    model.eval()
    return model
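# Assumed companion save step (not shown in this file): load_tars() unpickles the whole
# module, so the checkpoint would have been created with something along the lines of
#
#     model = TARSQuantumHybrid()
#     torch.save(model, "tars_v1.pt")
#
# A state_dict-only checkpoint would need a different loader.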
# ✅ Start chat loop
if __name__ == "__main__":
    model = load_tars()
    print("🤖 TARS model loaded successfully. Ready to chat!")
    while True:
        prompt = input("You: ")
        if prompt.strip().lower() in ["exit", "quit"]:
            print("TARS: Till we meet again in the quantum field. 🌌")
            break
        response = model.chat(prompt)
        print(f"TARS: {response}")