File size: 5,681 Bytes
9d83cd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import logging
import os
# The two TensorFlow variables are written before torch/transformers are
# imported; presumably so they are already in place if TensorFlow gets
# loaded transitively -- keep this ordering.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # Suppress TensorFlow logging
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" # Disable oneDNN optimizations
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import warnings
# Ignore warnings whose message starts with "A NumPy version >=".
warnings.filterwarnings("ignore", message="A NumPy version >=")
# Only ERROR and above is emitted, both by the root logger and by the
# "transformers" logger.
logging.basicConfig(level=logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
# Probe for the optional flash-attn package. The import is attempted purely
# for its side effect of succeeding or failing; the module itself is unused.
try:
    import flash_attn  # noqa: F401
except ImportError:
    flash_attn_exists = False
else:
    flash_attn_exists = True
# Define the DeepthoughtModel class
class DeepthoughtModel:
    '''Two-stage generator around the Deepthought causal language model.

    Stage one (`generate_reasoning`) streams a reasoning chain for a query;
    stage two (`generate_final_output`) streams the final answer conditioned
    on that reasoning. Both stages report their outcome as a dict with
    `success` and `error` keys instead of raising on generation errors.
    '''

    def __init__(self):
        '''Load the tokenizer and the model named by `self.model_name`.'''
        self.model_name = "ruliad/deepthought-8b-llama-v0.01-alpha"
        print(f"Loading model: {self.model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            add_bos_token=False,
            trust_remote_code=True,
            # `padding_side` is the tokenizer option that selects left
            # padding; the previous `padding="left"` keyword is not a
            # tokenizer constructor option.
            padding_side="left",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # "default" is not an accepted backend name. Without flash-attn,
            # pass None so Transformers picks its own default implementation.
            attn_implementation=("flash_attention_2" if flash_attn_exists else None),
            use_cache=True,
            trust_remote_code=True,
        )

    def _get_initial_prompt(
        self, query: str, system_message: str = None
    ) -> str:
        '''Build the prompt that opens the reasoning stage.

        Args:
            query: The user question inserted into the `user` turn.
            system_message: Text for the `system` turn. When None, a built-in
                default instruction is used.

        Returns:
            The prompt string. It ends inside the opening quote of the first
            reasoning step's "thought" field, so generation continues from
            there.
        '''
        if system_message is None:
            system_message = '''You are a superintelligent AI system, capable of comprehensive reasoning. When provided with <reasoning>, you must provide your logical reasoning chain to solve the user query. Be verbose with your outputs.'''
        # The literal below is part of the prompt; its lines must stay at
        # column 0 so no extra whitespace leaks into the text.
        return f'''<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>reasoning
<reasoning>
[
{{
"step": 1,
"type": "problem_understanding",
"thought": "'''

    def generate_reasoning(self, query: str, system_message: str = None) -> dict:
        '''Generate and stream the reasoning chain for `query`.

        Args:
            query: The user question.
            system_message: Optional override for the system turn; None uses
                the default from `_get_initial_prompt`.

        Returns:
            A dict with keys `raw_output` (decoded text of the full generated
            sequence), `success`, `error` and `initial_prompt`. If generation
            raises, `success` is False, `error` holds the message and the
            other two values are None.
        '''
        print('Generating reasoning...')
        # Build the prompt and echo it without a trailing newline.
        prompt = self._get_initial_prompt(query, system_message)
        print(prompt, end='')
        # Tokenize and move the ids to the device the model lives on.
        inputs = self.tokenizer(prompt, return_tensors='pt').input_ids.to(self.model.device)
        try:
            # The streamer prints new tokens as they are produced.
            outputs = self.model.generate(
                input_ids=inputs,
                max_new_tokens=800,
                do_sample=True,
                temperature=0.2,
                top_k=200,
                top_p=1.0,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True),
            )
            # Decode the first (only) returned sequence.
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return {
                'raw_output': generated_text,
                'success': True,
                'error': None,
                'initial_prompt': prompt,
            }
        except Exception as e:
            logging.error(f'Error during generation: {e}')
            return {
                'raw_output': None,
                'success': False,
                'error': str(e),
                'initial_prompt': None,
            }

    def generate_final_output(self, reasoning_output: dict) -> dict:
        '''Generate and stream the final answer from a reasoning result.

        Args:
            reasoning_output: The dict returned by `generate_reasoning`.

        Returns:
            A dict with keys `success` and `error`. A reasoning result that
            carries no `raw_output` or `initial_prompt` (for example a failed
            reasoning stage) yields a failure dict rather than an exception.
        '''
        raw_output = reasoning_output.get('raw_output')
        initial_prompt = reasoning_output.get('initial_prompt')
        if raw_output is None or initial_prompt is None:
            # A failed reasoning stage stores None here; calling .replace on
            # it would raise instead of honouring the error-dict contract.
            message = 'No reasoning output available for final generation'
            logging.error(f'Error during final generation: {message}')
            return {'success': False, 'error': message}
        # Strip the initial prompt from the reasoning text and open the
        # assistant turn.
        # NOTE(review): raw_output is decoded with skip_special_tokens=True,
        # so the prompt may not match verbatim here -- confirm against the
        # tokenizer's special-token set.
        reasoning_text = raw_output.replace(initial_prompt, '')
        full_prompt = f'''{reasoning_text}<|im_end|>
<|im_start|>assistant
'''
        print('Generating final response...')
        # Tokenize and move the ids to the device the model lives on.
        inputs = self.tokenizer(full_prompt, return_tensors='pt').input_ids.to(self.model.device)
        try:
            # Output is only streamed; the returned tensor is not needed.
            _ = self.model.generate(
                input_ids=inputs,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.1,
                top_k=50,
                top_p=0.9,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            )
            return {'success': True, 'error': None}
        except Exception as e:
            logging.error(f'Error during final generation: {e}')
            return {'success': False, 'error': str(e)}
def main():
    '''Run each demo query through the reasoning and final-answer stages.

    The model is constructed once. For every query the reasoning stage runs
    first; if it reports failure, the error is printed and the final-answer
    stage is skipped for that query.
    '''
    deepthought = DeepthoughtModel()

    # Demo prompts, handled one at a time because the output is streamed.
    queries = [
        "We want you to tell us the answer to life, the universe and everything. We'd really like an answer, something simple.",
    ]

    for query in queries:
        print(f'\nProcessing query: {query}')
        print('='*50)

        # Stage 1: reasoning chain.
        reasoning_result = deepthought.generate_reasoning(query)
        if reasoning_result['success']:
            print('-'*50)
            # Stage 2: final answer built on the reasoning.
            final_result = deepthought.generate_final_output(reasoning_result)
            if not final_result['success']:
                print(f'\nError in final generation: {final_result["error"]}')
        else:
            print(f'\nError in reasoning: {reasoning_result["error"]}')

        # Closing separator, printed on both the success and failure paths.
        print('='*50)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|