File size: 1,638 Bytes
80d3afc ace8c2c be95987 80d3afc 824f973 25dfad2 ace8c2c dbc76f4 ace8c2c 360f345 7de5168 94ff692 360f345 8e569e2 360f345 94ff692 360f345 336cc86 b50ab7b bfb5937 80d3afc 360f345 30a2066 80d3afc 5b0d141 3416cce 80d3afc 768bc38 80d3afc c4615f5 3416cce 80d3afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import gradio as gr
import json, os, vertexai, wandb
from dotenv import load_dotenv, find_dotenv
# Load variables from a .env file (if one is found) into the process environment.
_ = load_dotenv(find_dotenv())

# Required settings; a missing variable raises KeyError at start-up.
credentials = os.environ["CREDENTIALS"]  # parsed as JSON service-account info below
project = os.environ["PROJECT"]
wandb_api_key = os.environ["WANDB_API_KEY"]

# Single source of truth for the model id, so the value logged to
# Weights & Biases always matches the model that is actually loaded.
MODEL_NAME = "text-bison@001"

config = {
    "model": MODEL_NAME,
}

wandb.login(key=wandb_api_key)
wandb.init(project="vertex-ai-txt", config=config)
config = wandb.config

# `Request` is the HTTP transport passed to `credentials.refresh()`; it was
# previously referenced without being imported (NameError on the refresh path).
from google.auth.transport.requests import Request
from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_info(
    json.loads(credentials)
)

# NOTE(review): freshly built credentials presumably report `expired` as False,
# so this branch is likely rarely taken — confirm against google-auth docs.
if credentials.expired:
    credentials.refresh(Request())

vertexai.init(
    project=project,
    location="us-central1",
    credentials=credentials,
)

from vertexai.language_models import TextGenerationModel

generation_model = TextGenerationModel.from_pretrained(MODEL_NAME)
def invoke(prompt):
    """Return the model's text completion for *prompt*.

    Calls ``predict`` on the module-level ``generation_model`` and logs the
    prompt/completion pair through ``wandb.log`` before returning the text.
    """
    response = generation_model.predict(prompt=prompt)
    generated_text = response.text

    record = {"prompt": prompt, "completion": generated_text}
    wandb.log(record)

    return generated_text
# HTML blurb rendered under the title of the Gradio page.
description = """<a href='https://www.gradio.app/'>Gradio</a> UI using <a href='https://cloud.google.com/vertex-ai?hl=en/'>Google Vertex AI</a> API
with Bison foundation model. Model performance evaluation via <a href='https://wandb.ai/bstraehle'>Weights & Biases</a>."""

# Close any Gradio apps that are still running before creating a new one.
gr.close_all()

# One text box in, one text box out, wired to `invoke`.
demo = gr.Interface(
    fn=invoke,
    inputs=[gr.Textbox(label="Prompt", lines=1)],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM",
    description=description,
)

# The stray trailing `|` after this call was removed: it left an incomplete
# binary expression, which is a syntax error.
demo.launch()