# CryptoTrader-LM / run_inference.py
# Author: agarkovv — original upload "Create run_inference.py" (rev 97fdf26, 857 bytes)
import re

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
PROMPT = "YOUR PROMPT HERE"
MAX_LENGTH = 32768  # Do not change
# Fall back to CPU when no GPU is present instead of crashing on .to("cuda").
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# LoRA adapter trained for trading, applied on top of the Ministral-8B base model.
model_id = "agarkovv/Ministral-8B-Instruct-2410-LoRA-trading"
base_model_id = "mistralai/Ministral-8B-Instruct-2410"

model = AutoPeftModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
model = model.to(DEVICE)
model.eval()

# Tokenize the prompt, truncating so it never exceeds the model context window.
inputs = tokenizer(
    PROMPT, return_tensors="pt", padding=False, max_length=MAX_LENGTH, truncation=True
)
inputs = {key: value.to(model.device) for key, value in inputs.items()}

# Cap generation so prompt + completion stays within MAX_LENGTH total tokens;
# the original max_new_tokens=MAX_LENGTH could push the sequence past the
# context window when the prompt is long.
max_new_tokens = max(1, MAX_LENGTH - inputs["input_ids"].shape[1])

# inference_mode: no autograd bookkeeping during generation — lower memory, faster.
with torch.inference_mode():
    res = model.generate(
        **inputs,
        use_cache=True,
        max_new_tokens=max_new_tokens,
    )

output = tokenizer.decode(res[0], skip_special_tokens=True)
# Strip everything up to and including the last "[/INST]" marker, keeping only
# the assistant's answer.
# NOTE(review): this assumes "[/INST]" survives skip_special_tokens=True decoding;
# if the tokenizer treats it as a special token, the regex matches nothing and the
# full decoded text (prompt included) is printed — verify against the tokenizer.
answer = re.sub(r".*\[/INST\]\s*", "", output, flags=re.DOTALL)
print(answer)