File size: 594 Bytes
420fa8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

# Populate os.environ from a .env file (if one is found) so that HF_TOKEN,
# which gemma_predict reads from the environment, can be supplied via .env.
load_dotenv()


def gemma_predict(
    combined_information,
    model_name,
    *,
    max_new_tokens=2048,
    stop_marker="<eos>",
):
    """Generate a completion for a prompt via a Hugging Face ``InferenceClient``.

    The response is streamed token by token, assembled into one string and
    truncated at the first occurrence of ``stop_marker``.

    Args:
        combined_information: Prompt text sent to the model.
        model_name: Model identifier passed to ``InferenceClient``.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 2048, the value previously hard-coded).
        stop_marker: Text at which the output is cut; the marker and
            everything after it are dropped. Defaults to ``"<eos>"``.
            An empty value disables truncation.

    Returns:
        The generated text only (the prompt is not echoed, because
        ``return_full_text=False``), truncated at ``stop_marker``.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    hf_token = os.environ["HF_TOKEN"]
    client = InferenceClient(model_name, token=hf_token)
    stream = client.text_generation(
        prompt=combined_information,
        # details=True makes each streamed item expose `.token.text`.
        details=True,
        stream=True,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
    )

    # Join the streamed pieces once instead of growing a string with `+=`.
    output = "".join(response.token.text for response in stream)

    # str.partition("") raises ValueError, so an empty marker means "no cut".
    if not stop_marker:
        return output
    # partition returns the whole string as [0] when the marker is absent,
    # so no separate membership check is needed.
    return output.partition(stop_marker)[0]