Inference


# Fetch the tokenizer and the causal-LM weights from the Hugging Face Hub.
# NOTE(review): the "Model tree" section of this page names
# "suriya7/conversational-gpt-V1", while the id below is
# "suriya7/conversational-gpt-1" — confirm which repo id is current.
from transformers import AutoModelForCausalLM, GPT2Tokenizer

MODEL_ID = "suriya7/conversational-gpt-1"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

Chatting


import torch
from transformers import TextStreamer

# System turn in ChatML format; prepended to every prompt.
prompt = """
<|im_start|>system\nYou are a helpful AI assistant named Securitron, trained by Aquilax.<|im_end|>
"""

# Token id used as both the pad and the end-of-sequence id during generation.
# NOTE(review): presumably the id of "<|im_end|>" in this tokenizer — confirm
# with tokenizer.convert_ids_to_tokens(50259).
EOS_TOKEN_ID = 50259
# Upper bound on the number of tokens generated for one assistant reply.
MAX_NEW_TOKENS = 512
# History entries kept per turn: the two previous user/assistant exchanges
# plus the current user message, so the window always starts with a user turn.
HISTORY_WINDOW = 5

# Pick the device and move the model once, instead of on every turn.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Longest prompt (in tokens) that still leaves room for MAX_NEW_TOKENS within
# the tokenizer's maximum sequence length.
prompt_budget = max(tokenizer.model_max_length - MAX_NEW_TOKENS, 1)

# Prints decoded text as tokens are produced (skipping the prompt and special
# tokens) and emits a final newline when generation ends.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Rolling window of formatted ChatML turns (user and assistant).
conversation_history = []

while True:
    try:
        user_prompt = input("User Question: ")
    except (EOFError, KeyboardInterrupt):
        # Treat end of input / Ctrl-C like the "break" command.
        print()
        break
    if user_prompt.strip().lower() == 'break':
        break

    conversation_history.append(f"<|im_start|>user\n{user_prompt}<|im_end|>")
    conversation_history = conversation_history[-HISTORY_WINDOW:]

    # Build the prompt, ending with the assistant cue so the model answers in
    # the assistant role. While it does not fit, drop the oldest exchange;
    # dropping two entries at a time keeps the window starting with a user turn.
    while True:
        current_prompt = (
            prompt
            + "\n".join(conversation_history)
            + "\n<|im_start|>assistant\n"
        )
        encoded = tokenizer(current_prompt, return_tensors="pt")
        if (
            encoded.input_ids.shape[1] <= prompt_budget
            or len(conversation_history) <= 1
        ):
            break
        conversation_history = conversation_history[2:]

    input_ids = encoded.input_ids
    attention_mask = encoded.attention_mask
    if input_ids.shape[1] > prompt_budget:
        # A single oversized message: keep its most recent tokens rather than
        # cutting off the end of the question.
        input_ids = input_ids[:, -prompt_budget:]
        attention_mask = attention_mask[:, -prompt_budget:]
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # One generate call reuses the KV cache across steps; the streamer prints
    # the reply incrementally, matching the original token-by-token display.
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=EOS_TOKEN_ID,
        eos_token_id=EOS_TOKEN_ID,
        num_return_sequences=1,
        do_sample=True,   # Use sampling for more diverse responses
        top_k=50,         # Limit to the top-k tokens to sample from
        temperature=0.7,  # Adjust temperature for randomness
        top_p=0.90,
        streamer=streamer,
    )

    # Decode only the newly generated tokens.
    assistant_response = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    )

    # Record the reply with its role so later prompts stay well-formed ChatML.
    conversation_history.append(
        f"<|im_start|>assistant\n{assistant_response.strip()}<|im_end|>"
    )
Downloads last month: 354
Safetensors · Model size: 774M params · Tensor type: F32
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for suriya7/conversational-gpt-V1

Quantizations: 1 model