Inference
# Load model directly
from transformers import AutoModelForCausalLM, GPT2Tokenizer
# Hugging Face Hub repository that provides both the tokenizer files and the weights.
MODEL_ID = "suriya7/conversational-gpt-1"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
Chatting
import torch
from transformers import TextStreamer

# ChatML-style system prompt prepended to every request.
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
"""

# Token id passed to generate() as both the stop (EOS) and padding token.
# NOTE(review): presumably the id of "<|im_end|>" in this checkpoint's
# vocabulary -- confirm with tokenizer.convert_tokens_to_ids("<|im_end|>").
EOS_TOKEN_ID = 50259
# Upper bound on the length of one assistant reply, in tokens.
MAX_NEW_TOKENS = 512
# History entries kept per request: the last 2 user/assistant exchanges
# (4 turns) plus the current user question, so the window starts with a user turn.
HISTORY_WINDOW = 5

# Device placement is loop-invariant: choose it and move the model once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Reserve room for the reply inside the model's positional limit; without this a
# long prompt plus MAX_NEW_TOKENS can run past the position embeddings.
max_positions = getattr(model.config, "max_position_embeddings", None) or tokenizer.model_max_length
max_prompt_tokens = max(1, max_positions - MAX_NEW_TOKENS)
# Drop the *oldest* tokens when the prompt is too long, so the newest question survives.
tokenizer.truncation_side = "left"

# Prints decoded text as it is generated; it buffers until whole words/characters
# are available (so multi-byte UTF-8 is never split) and prints a newline at the end.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Rolling window of ChatML-formatted turns (user and assistant alternate).
conversation_history = []
while True:
    user_prompt = input("User Question: ")
    if user_prompt.lower() == 'break':
        break

    # Record the user's turn, then keep only the most recent window.
    conversation_history.append(f"""<|im_start|>user
{user_prompt}<|im_end|>""")
    conversation_history = conversation_history[-HISTORY_WINDOW:]

    # End with the assistant header so the model answers in the assistant role
    # instead of having to emit the role header itself.
    current_prompt = prompt + "\n".join(conversation_history) + "\n<|im_start|>assistant\n"

    encoded = tokenizer(
        current_prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_tokens,
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # One generate() call reuses the KV cache across steps, instead of
    # re-encoding the whole sequence for every single new token.
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=EOS_TOKEN_ID,
        eos_token_id=EOS_TOKEN_ID,
        num_return_sequences=1,
        do_sample=True,  # Use sampling for more diverse responses
        top_k=50,  # Limit to the top-k tokens to sample from
        temperature=0.7,  # Adjust temperature for randomness
        top_p=0.90,
        streamer=streamer,
    )

    # Decode the reply as a whole (not token by token) so multi-byte characters
    # are reconstructed correctly before they go into the history.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    assistant_response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Store the reply with its role, mirroring how user turns are stored.
    conversation_history.append(f"<|im_start|>assistant\n{assistant_response.strip()}<|im_end|>")
- Downloads last month: 354
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social
visibility and check back later, or deploy to Inference Endpoints (dedicated)
instead.