File size: 1,634 Bytes
80d3afc
ace8c2c
be95987
80d3afc
 
 
824f973
25dfad2
 
 
ace8c2c
 
 
 
 
cec3f4d
ace8c2c
 
360f345
7de5168
94ff692
360f345
8e569e2
360f345
 
94ff692
360f345
 
 
336cc86
b50ab7b
cec3f4d
80d3afc
360f345
30a2066
 
 
80d3afc
5b0d141
 
3416cce
80d3afc
 
768bc38
80d3afc
c4615f5
3416cce
80d3afc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import json, os, vertexai, wandb

from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())

# Required settings, read from the environment (populated from a .env file
# by load_dotenv above when one is found). A missing variable raises KeyError
# at startup.
credentials = os.environ["CREDENTIALS"]
project = os.environ["PROJECT"]
wandb_api_key = os.environ["WANDB_API_KEY"]

# Run configuration recorded with the W&B run; "model" selects the Vertex AI
# model loaded at the end of this section.
config = {
    "model": "text-bison@001",
}

# Authenticate with Weights & Biases and start a run. After init, `config`
# is rebound to the run's wandb.config object.
wandb.login(key = wandb_api_key)
wandb.init(project = "vertex-ai-llm", config = config)
config = wandb.config

# CREDENTIALS holds the service-account key as JSON text; parse it into the
# mapping that from_service_account_info expects.
credentials = json.loads(credentials)

# Request is the HTTP transport that credentials.refresh() needs. It was
# previously used without being imported, so the refresh branch below would
# have raised NameError instead of refreshing the token.
from google.auth.transport.requests import Request
from google.oauth2 import service_account
credentials = service_account.Credentials.from_service_account_info(credentials)

if credentials.expired:
    credentials.refresh(Request())

vertexai.init(project = project, 
              location = "us-central1",
              credentials = credentials)

# Load the text generation model named in the run configuration.
from vertexai.language_models import TextGenerationModel
generation_model = TextGenerationModel.from_pretrained(config.model)

def invoke(prompt):
    """Generate a completion for ``prompt`` and log the exchange to W&B.

    Args:
        prompt: Text passed to the module-level ``generation_model``.

    Returns:
        The ``text`` attribute of the model's prediction response.
    """
    response = generation_model.predict(prompt = prompt)
    generated_text = response.text

    # Record the prompt/completion pair on the active W&B run.
    record = {"prompt": prompt, "completion": generated_text}
    wandb.log(record)

    return generated_text

# HTML blurb rendered under the title of the Gradio page.
description = """<a href='https://www.gradio.app/'>Gradio</a> UI using <a href='https://cloud.google.com/vertex-ai?hl=en/'>Google Vertex AI</a> API 
                 with Bison foundation model. Model performance evaluation via <a href='https://wandb.ai/bstraehle'>Weights & Biases</a>."""

# Close any Gradio interfaces that are already running before building a new one.
gr.close_all()

# One single-line text box in, one single-line text box out.
prompt_box = gr.Textbox(label = "Prompt", lines = 1)
completion_box = gr.Textbox(label = "Completion", lines = 1)

# Wire the boxes to invoke() and start serving the UI.
demo = gr.Interface(
    fn = invoke,
    inputs = [prompt_box],
    outputs = [completion_box],
    title = "Generative AI - LLM",
    description = description,
)
demo.launch()