# TARS-v1 / chat_with_tars.py
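# Hybrid BERT → GPT-Neo chat script: BERT encodes the prompt, the [CLS]
# embedding is projected into GPT-Neo's embedding space, and GPT-Neo decodes
# a short reply, optionally enriched by the "brain" and "heart" modules.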
import os
import torch
import torch.nn as nn
from transformers import BertModel, GPTNeoForCausalLM, AutoTokenizer
# βš™οΈ Ensure temporary directory is writable
os.environ["TMPDIR"] = os.path.expanduser("~/tmp")
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
# πŸ’  Optional modules (brain & heart, if available)
heart_module = None
brain_module = None
if os.path.isdir("heart"):
try:
from heart import heart
heart_module = heart
except Exception as e:
print(f"[⚠️] Heart module error: {e}")
if os.path.isdir("brain"):
try:
from brain import brain
brain_module = brain
except Exception as e:
print(f"[⚠️] Brain module error: {e}")
# TARSQuantumHybrid class
class TARSQuantumHybrid(nn.Module):
    def __init__(self, bert_model="bert-base-uncased", gpt_model="EleutherAI/gpt-neo-125M"):
        super(TARSQuantumHybrid, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.gpt = GPTNeoForCausalLM.from_pretrained(gpt_model)
        gpt_hidden_dim = getattr(self.gpt.config, "hidden_size", None) or getattr(self.gpt.config, "n_embd", 768)
        self.embedding_proj = nn.Linear(self.bert.config.hidden_size, gpt_hidden_dim)
        self.tokenizer = AutoTokenizer.from_pretrained(gpt_model)

        # Ensure the tokenizer has a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.gpt.resize_token_embeddings(len(self.tokenizer))
            print("βœ… Padding token added and model resized.")
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = bert_output.last_hidden_state[:, 0, :]
        gpt_input = self.embedding_proj(cls_embedding).unsqueeze(1)
        if decoder_input_ids is not None:
            # GPT-Neo is decoder-only and does not accept decoder_input_ids;
            # embed them and append after the projected BERT [CLS] vector instead.
            decoder_embeds = self.gpt.get_input_embeddings()(decoder_input_ids)
            gpt_input = torch.cat([gpt_input, decoder_embeds], dim=1)
        outputs = self.gpt(inputs_embeds=gpt_input)
        return outputs
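
    # Full chat pipeline: sanitize the prompt, tokenize it, run one forward
    # pass, greedily decode the logits, strip the echoed prompt, and append
    # any brain/heart insights.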
    def chat(self, text, max_length=128):
        # 🧠 Tokenize the input text
        cleaned_text = self.clean_input_text(text)
        if not cleaned_text.strip():
            return "πŸ€– Please provide a non-empty input."

        encoded_input = self.safe_tokenization(cleaned_text)

        # Extract input_ids and attention_mask
        input_ids = encoded_input["input_ids"]
        attention_mask = encoded_input["attention_mask"]

        # Debug: check the token IDs and vocabulary size
        print(f"Input Text: {cleaned_text}")
        print(f"Input IDs: {input_ids}")
        print(f"Vocabulary Size: {len(self.tokenizer)}")

        # Ensure token IDs are within bounds (len() includes added tokens such as [PAD])
        if input_ids.numel() > 0 and input_ids.max() >= len(self.tokenizer):
            raise ValueError(f"Token ID exceeds model's vocabulary size: {input_ids.max()}")

        decoder_input_ids = torch.tensor([[self.tokenizer.bos_token_id]])

        # πŸ§ͺ Generate output using the model
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
            )
        generated_ids = torch.argmax(outputs.logits, dim=-1)

        # Debug: check the generated token IDs
        print(f"Generated Token IDs: {generated_ids}")
        raw_response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # 🧼 Clean model echo by removing the original input from the response
        cleaned = raw_response.replace(cleaned_text, "").strip()

        # 🧠 Add insights from optional modules (brain & heart)
        extra_thoughts = ""
        if brain_module and hasattr(brain_module, "get_brain_insight"):
            extra_thoughts += f"\n🧠 {brain_module.get_brain_insight()}"
        if heart_module and hasattr(heart_module, "get_heart_feeling"):
            extra_thoughts += f"\n❀️ {heart_module.get_heart_feeling()}"

        # πŸͺ„ Return final response
        final_response = cleaned if cleaned else "πŸ€– ...processing quantum entanglement..."
        return final_response + extra_thoughts
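
    # Helpers below sanitize raw user text and guard tokenization against
    # out-of-range ids before it reaches forward().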
    def clean_input_text(self, text):
        # Keep only alphanumeric characters and whitespace
        cleaned_text = ''.join(e for e in text if e.isalnum() or e.isspace())
        return cleaned_text
    def safe_tokenization(self, text):
        token_ids = self.tokenizer.encode(text, add_special_tokens=True)
        # Clamp any out-of-range ids to the last valid index of the (resized) embedding table
        token_ids = [min(i, len(self.tokenizer) - 1) for i in token_ids]
        return {
            "input_ids": torch.tensor(token_ids).unsqueeze(0),  # add batch dimension
            "attention_mask": torch.ones((1, len(token_ids)), dtype=torch.long),
        }
# βœ… Torch-compatible loader
def load_tars(path="tars_v1.pt"):
    from torch.serialization import add_safe_globals
    # add_safe_globals expects a list of allowed classes
    add_safe_globals([TARSQuantumHybrid])
    model = torch.load(path, weights_only=False)
    model.eval()
    return model
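
# Note: load_tars expects a fully pickled model object. A compatible checkpoint
# would typically be produced elsewhere (assumed, not shown in this file) with:
#   model = TARSQuantumHybrid()
#   torch.save(model, "tars_v1.pt")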
# βœ… Start chat loop
if __name__ == "__main__":
    model = load_tars()
    print("πŸ€– TARS model loaded successfully. Ready to chat!")
    while True:
        prompt = input("You: ")
        if prompt.strip().lower() in ["exit", "quit"]:
            print("TARS: Till we meet again in the quantum field. 🌌")
            break
        response = model.chat(prompt)
        print(f"TARS: {response}")