from transformers import pipeline

import config


def predict(prompt, model, tokenizer, max_length):
    """Generate a completion for `prompt` with the given model and tokenizer."""
    # Build a transformers pipeline for the task named in config (e.g. "text-generation").
    pipe = pipeline(
        task=config.TASK,
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
    )
    # Wrap the prompt in the Llama 2 chat instruction template before generating.
    result = pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]["generated_text"]
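

# Example usage (a sketch, not part of the original file): it assumes
# config.TASK is "text-generation" and a Llama-2-style chat model, since
# predict() wraps the prompt in [INST] tags. The checkpoint name below is
# illustrative only.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "meta-llama/Llama-2-7b-chat-hf"  # hypothetical checkpoint
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    print(predict("Explain gravity in one sentence.", model, tokenizer, max_length=200))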