import os

import torch
import torch.nn as nn
from transformers import BertModel, GPTNeoForCausalLM, AutoTokenizer
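
# Redirect temporary files to a user-writable directory.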
os.environ["TMPDIR"] = os.path.expanduser("~/tmp")
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
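
# Optional companion modules; TARS runs without them if they are absent.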
heart_module = None
brain_module = None

if os.path.isdir("heart"):
    try:
        from heart import heart
        heart_module = heart
    except Exception as e:
        print(f"[⚠️] Heart module error: {e}")

if os.path.isdir("brain"):
    try:
        from brain import brain
        brain_module = brain
    except Exception as e:
        print(f"[⚠️] Brain module error: {e}")


class TARSQuantumHybrid(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", gpt_model="EleutherAI/gpt-neo-125M"):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.gpt = GPTNeoForCausalLM.from_pretrained(gpt_model)
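
        # GPT-Neo reports its embedding width as hidden_size; fall back to
        # the GPT-2-style n_embd, then to 768.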
        gpt_hidden_dim = getattr(self.gpt.config, "hidden_size", None) or getattr(self.gpt.config, "n_embd", 768)
        self.embedding_proj = nn.Linear(self.bert.config.hidden_size, gpt_hidden_dim)

        self.tokenizer = AutoTokenizer.from_pretrained(gpt_model)
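
        # GPT-Neo's tokenizer ships without a pad token; add one and resize
        # the embedding matrix to match.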
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.gpt.resize_token_embeddings(len(self.tokenizer))
            print("✅ Padding token added and model resized.")

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = bert_output.last_hidden_state[:, 0, :]  # [CLS] summary vector
        gpt_input = self.embedding_proj(cls_embedding).unsqueeze(1)

        # GPT-Neo is decoder-only and accepts no decoder_input_ids argument;
        # embed the seed tokens and append them after the projected BERT vector.
        if decoder_input_ids is not None:
            seed_embeds = self.gpt.get_input_embeddings()(decoder_input_ids)
            gpt_input = torch.cat([gpt_input, seed_embeds], dim=1)
        outputs = self.gpt(inputs_embeds=gpt_input)
        return outputs

    def chat(self, text, max_length=128):
        cleaned_text = self.clean_input_text(text)
        if not cleaned_text.strip():
            return "🤖 Please provide a non-empty input."

        encoded_input = self.safe_tokenization(cleaned_text)
        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]

        print(f"Input Text: {cleaned_text}")
        print(f"Input IDs: {input_ids}")
        print(f"Vocabulary Size: {len(self.tokenizer)}")

        # len(self.tokenizer) includes added tokens such as [PAD], which
        # tokenizer.vocab_size does not.
        if input_ids.numel() > 0 and input_ids.max() >= len(self.tokenizer):
            raise ValueError(f"Token ID exceeds model's vocabulary size: {input_ids.max()}")

        decoder_input_ids = torch.tensor([[self.tokenizer.bos_token_id]])
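
        # One forward pass; the decode below is single-step greedy rather
        # than autoregressive generation.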
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
            )
        generated_ids = torch.argmax(outputs.logits, dim=-1)

        print(f"Generated Token IDs: {generated_ids}")

        raw_response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
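
        # Drop any echo of the prompt from the decoded text.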
        cleaned = raw_response.replace(cleaned_text, "").strip()

        extra_thoughts = ""
        if brain_module and hasattr(brain_module, "get_brain_insight"):
            extra_thoughts += f"\n🧠 {brain_module.get_brain_insight()}"
        if heart_module and hasattr(heart_module, "get_heart_feeling"):
            extra_thoughts += f"\n❤️ {heart_module.get_heart_feeling()}"

        final_response = cleaned if cleaned else "🤖 ...processing quantum entanglement..."
        return final_response + extra_thoughts

    def clean_input_text(self, text):
        # Keep only alphanumeric characters and whitespace.
        cleaned_text = ''.join(e for e in text if e.isalnum() or e.isspace())
        return cleaned_text

    def safe_tokenization(self, text):
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        # Defensive clamp; len(self.tokenizer) counts added tokens such as
        # [PAD], which tokenizer.vocab_size does not.
        token_ids = [min(i, len(self.tokenizer) - 1) for i in token_ids]
        return {
            "input_ids": torch.tensor(token_ids).unsqueeze(0),
            "attention_mask": torch.ones((1, len(token_ids)), dtype=torch.long)
        }


def load_tars(path="tars_v1.pt"):
    # Allowlist the custom class for PyTorch 2.6+ checkpoint loading;
    # add_safe_globals expects a list of objects, not a dict.
    from torch.serialization import add_safe_globals
    add_safe_globals([TARSQuantumHybrid])

    model = torch.load(path, weights_only=False)
    model.eval()
    return model


# Interactive chat loop: type "exit" or "quit" to leave.
if __name__ == "__main__":
    model = load_tars()
    print("🤖 TARS model loaded successfully. Ready to chat!")

    while True:
        prompt = input("You: ")
        if prompt.strip().lower() in ["exit", "quit"]:
            print("TARS: Till we meet again in the quantum field. 🌌")
            break
        response = model.chat(prompt)
        print(f"TARS: {response}")