File size: 1,988 Bytes
091596c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from utils.load_config import LoadConfig
# Project configuration; this module reads llm_engine, gemma_token and device.
APPCFG = LoadConfig()

app = Flask(__name__)

# Build the tokenizer, the model, and the text-generation pipeline once at
# import time so that every request reuses the same objects.
# NOTE(review): the tokenizer is resolved from APPCFG.llm_engine while the
# model id below is hard-coded -- confirm that both refer to the same
# checkpoint, otherwise the vocabulary/chat template may not match the model.
tokenizer = AutoTokenizer.from_pretrained(
    APPCFG.llm_engine,
    token=APPCFG.gemma_token,
    device=APPCFG.device,
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="BioMistral/BioMistral-7B",
    token=APPCFG.gemma_token,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map=APPCFG.device,
)
app_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Endpoint to generate text


@app.route("/generate_text", methods=["POST"])
def generate_text():
    """Generate a completion for the prompt carried in the POSTed JSON body.

    Recognised JSON fields:
        prompt: Required. Either the conversation to render with the
            tokenizer's chat template (presumably a list of
            ``{"role": ..., "content": ...}`` dicts -- confirm with callers),
            or a plain string, which is treated as a single user message.
        max_new_tokens: Defaults to 1000.
        do_sample: Defaults to True.
        temperature: Defaults to 0.1.
        top_k: Defaults to 50.
        top_p: Defaults to 0.95.

    The sampling fields are forwarded to the pipeline unchanged.

    Returns:
        A JSON object ``{"response": <generated text>}`` containing only the
        newly generated text, or ``{"error": <message>}`` with HTTP status
        400 when the body is not a JSON object or the prompt is missing/empty.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so bad requests get a clean 400 rather than a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "Missing or empty 'prompt'."}), 400
    if isinstance(prompt, str):
        # The chat template works on a list of messages; wrap a bare string
        # as a one-turn user conversation.
        prompt = [{"role": "user", "content": prompt}]

    max_new_tokens = data.get("max_new_tokens", 1000)
    do_sample = data.get("do_sample", True)
    temperature = data.get("temperature", 0.1)
    top_k = data.get("top_k", 50)
    top_p = data.get("top_p", 0.95)

    # Render the conversation to a single prompt string (not token ids).
    tokenized_prompt = app_pipeline.tokenizer.apply_chat_template(
        prompt, tokenize=False, add_generation_prompt=True)

    # Generate text based on the prompt. return_full_text=False makes the
    # pipeline return only the continuation, so the prompt does not have to
    # be sliced off the output by character count.
    response = app_pipeline(
        tokenized_prompt,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        return_full_text=False
    )
    generated_text = response[0]["generated_text"]

    print("==================")
    print("top_k:", top_k, "top_p:", top_p, "temperature:",
          temperature, "max_new_tokens:", max_new_tokens)
    print("==================")
    print(generated_text)
    print("==================")

    return jsonify({"response": generated_text})


if __name__ == "__main__":
    # Start Flask's built-in server only when this file is run as a script;
    # no host is passed, so Flask's default bind address applies.
    app.run(port=8888, debug=False)