File size: 776 Bytes
420fa8a
d5b3118
420fa8a
 
 
 
 
 
 
d5b3118
420fa8a
 
d5b3118
 
 
420fa8a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

load_dotenv()


def gemma_predict(combined_information, model_name, config: Optional[dict]):
    """Stream a completion for ``combined_information`` from a Hugging Face endpoint.

    Args:
        combined_information: The fully-assembled prompt string to send to the model.
        model_name: Hugging Face model id (or endpoint URL) passed to ``InferenceClient``.
        config: Optional generation settings; ``max_output_tokens`` and
            ``temperature`` keys are read if present. When ``None`` (or a key
            is missing), the inference API's own defaults are used.

    Returns:
        The generated text, truncated at the first ``<eos>`` marker if one appears.

    Raises:
        KeyError: If the ``HF_TOKEN`` environment variable is not set.
    """
    hf_token = os.environ["HF_TOKEN"]  # fail fast with KeyError if unset
    client = InferenceClient(model_name, token=hf_token)

    # Honor the Optional[dict] annotation: the original crashed with a
    # TypeError on config=None. Missing keys now fall back to the client's
    # defaults (None) instead of raising KeyError.
    cfg = config or {}
    stream = client.text_generation(
        prompt=combined_information,
        details=True,
        stream=True,
        max_new_tokens=cfg.get("max_output_tokens"),
        temperature=cfg.get("temperature"),
        return_full_text=False,
    )

    # Collect streamed token texts and join once — avoids quadratic +=
    # string concatenation.
    output = "".join(response.token.text for response in stream)

    # Drop the end-of-sequence marker and anything after it; split() is a
    # no-op when the marker is absent, so no membership check is needed.
    return output.split("<eos>")[0]