File size: 2,915 Bytes
80d3afc
ace8c2c
be95987
80d3afc
 
 
7a49355
d0ccbf6
 
25dfad2
6919ace
3684d73
241921c
 
5c795c7
 
 
6919ace
ace8c2c
 
cec3f4d
ace8c2c
360f345
7de5168
94ff692
360f345
8e569e2
360f345
 
94ff692
360f345
 
031aa40
7a49355
336cc86
241921c
563878a
241921c
563878a
80d3afc
360f345
22f9400
563878a
 
 
 
22f9400
0a17b5c
563878a
 
 
 
0a17b5c
 
 
 
80d3afc
5b0d141
e3184f5
3416cce
80d3afc
 
768bc38
80d3afc
c4615f5
3416cce
80d3afc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import json, os, vertexai, wandb

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Required settings, read in a fixed order; a missing variable raises KeyError
# during start-up rather than later inside a request.
credentials, project, wandb_api_key = (
    os.environ[name] for name in ("CREDENTIALS", "PROJECT", "WANDB_API_KEY")
)

# Generation settings. The same mapping is handed to wandb.init() as the run
# config, and config["model"] selects the model that gets instantiated.
# An earlier revision used model="text-bison@001".
config = dict(
    max_output_tokens=800,
    model="gemini-pro",
    temperature=0.1,
    top_k=40,
    top_p=1.0,
)

# Authenticate against Weights & Biases, then open the run that records the
# generation settings for this app.
wandb.login(key=wandb_api_key)
wandb.init(
    project="vertex-ai-llm",
    config=config,
)

# CREDENTIALS carries the service-account key as a JSON document; parse it into
# the mapping passed to from_service_account_info() below.
credentials = json.loads(credentials)

# Request is the HTTP transport handed to credentials.refresh(). It was used
# below without ever being imported, which raised NameError whenever the
# refresh branch ran.
from google.auth.transport.requests import Request
from google.oauth2 import service_account
credentials = service_account.Credentials.from_service_account_info(credentials)

# NOTE(review): credentials built just above presumably report expired == False
# (no token has been fetched yet), so this guard is likely a no-op -- confirm.
if credentials.expired:
    credentials.refresh(Request())

vertexai.init(
    project=project,
    location="us-central1",
    credentials=credentials,
)

# Earlier revision, kept for reference:
#from vertexai.language_models import TextGenerationModel
#generation_model = TextGenerationModel.from_pretrained(config["model"])
from vertexai.preview.generative_models import GenerativeModel
generation_model = GenerativeModel(config["model"])

def invoke(prompt):
    """Return the text shown in the "Completion" box for ``prompt``.

    The live model call is switched off in this deployment: ``prompt`` is
    ignored and a fixed notice pointing at the source repository is returned.
    The disabled implementations are kept below as reference.
    """
    # Disabled variant 1 -- generation_model.predict(...):
    #completion = generation_model.predict(prompt = prompt,
    #                                      max_output_tokens = config["max_output_tokens"],
    #                                      temperature = config["temperature"],
    #                                      top_k = config["top_k"],
    #                                      top_p = config["top_p"],
    #                                     ).text
    # Disabled variant 2 -- generation_model.generate_content(...):
    #completion = generation_model.generate_content(prompt, generation_config = {
    #                                                           "max_output_tokens": config["max_output_tokens"],
    #                                                           "temperature": config["temperature"],
    #                                                           "top_k": config["top_k"],
    #                                                           "top_p": config["top_p"],
    #                                                       }).text
    #wandb.log({"prompt": prompt, "completion": completion})
    #return completion
    disabled_notice = "🛑 Execution is commented out. To view the source code see https://huggingface.co/spaces/bstraehle/google-vertex-ai-llm/tree/main."
    return disabled_notice

# HTML blurb rendered under the title of the Gradio page.
description = """<a href='https://www.gradio.app/'>Gradio</a> UI using <a href='https://cloud.google.com/vertex-ai?hl=en/'>Google Vertex AI</a> API 
                 with gemini-pro foundation model. Model performance evaluation via <a href='https://wandb.ai/bstraehle'>Weights & Biases</a>."""

# Close any Gradio apps that are still open before starting this one.
gr.close_all()

# One single-line text box in, one single-line text box out.
prompt_box = gr.Textbox(label="Prompt", lines=1)
completion_box = gr.Textbox(label="Completion", lines=1)

demo = gr.Interface(
    fn=invoke,
    inputs=[prompt_box],
    outputs=[completion_box],
    title="Generative AI - LLM",
    description=description,
)
demo.launch()