File size: 2,592 Bytes
80d3afc ace8c2c be95987 80d3afc 7a49355 d0ccbf6 25dfad2 6919ace 3684d73 241921c 5c795c7 6919ace ace8c2c 360f345 7de5168 94ff692 360f345 8e569e2 360f345 94ff692 360f345 031aa40 7a49355 336cc86 241921c 563878a 80d3afc e75bad7 a36a82c e75bad7 360f345 f45a806 cf215da 22ebce8 cf215da 22a1dd5 cf215da a84efc1 22a1dd5 4cd17a7 cf215da 22a1dd5 66f39f2 cf215da aa69ed9 80d3afc 1a556a4 98e4ff7 3416cce 80d3afc cf215da 768bc38 80d3afc c4615f5 3416cce cf215da 80d3afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
import json, os, vertexai, wandb
from dotenv import load_dotenv, find_dotenv
# Load variables from the .env file located by find_dotenv(), if any.
_ = load_dotenv(find_dotenv())

# Required settings, read in a fixed order; a missing variable raises
# KeyError while the module is being loaded.
credentials, project, wandb_api_key = (
    os.environ[name] for name in ("CREDENTIALS", "PROJECT", "WANDB_API_KEY")
)

# Model name and generation parameters. The same dict is passed to
# wandb.init() as the run config and read when calling the model.
config = dict(
    max_output_tokens=800,
    model="gemini-pro",
    temperature=0.1,
    top_k=40,
    top_p=1.0,
)
# Build service-account credentials from the JSON text read from the
# CREDENTIALS environment variable, then initialise the Vertex AI SDK and
# create the shared model handle.
from google.auth.transport.requests import Request
from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_info(
    json.loads(credentials)
)

# Refresh an expired token before handing the credentials to the SDK.
# Request is the HTTP transport google-auth needs for the refresh call.
if credentials.expired:
    credentials.refresh(Request())

vertexai.init(
    project=project,
    location="us-central1",
    credentials=credentials,
)

# Imported after vertexai.init(), keeping the original statement order.
from vertexai.preview.generative_models import GenerativeModel

# Single model instance reused by every request.
generation_model = GenerativeModel(config["model"])
def wandb_log(prompt, completion):
    """Record one prompt/completion pair as a Weights & Biases run.

    Each call logs in with the module-level ``wandb_api_key``, starts a run
    in the "vertex-ai-llm" project with the module-level ``config``, logs
    the pair, and finishes the run.

    Args:
        prompt: Value sent to the model; stored as ``str(prompt)``.
        completion: Model output (or the exception that was raised);
            stored as ``str(completion)``.
    """
    wandb.login(key=wandb_api_key)
    wandb.init(project="vertex-ai-llm", config=config)
    try:
        wandb.log({"prompt": str(prompt), "completion": str(completion)})
    finally:
        # Finish the run even if logging fails so it is not left open.
        wandb.finish()
def invoke(prompt):
    """Send ``prompt`` to the generative model and return its completion.

    The prompt/completion pair (or prompt/exception pair on failure) is
    always passed to ``wandb_log`` before the function exits.

    Args:
        prompt: Text entered by the user.

    Returns:
        The response's ``text`` when it is not None; otherwise the response
        object returned by ``generate_content``.

    Raises:
        gr.Error: If the prompt is missing, empty, or whitespace-only, or
            if the model call raises any exception (chained as the cause).
    """
    # Reject None, "" and whitespace-only prompts before calling the model.
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is required.")

    completion = ""
    try:
        response = generation_model.generate_content(
            prompt,
            generation_config={
                "max_output_tokens": config["max_output_tokens"],
                "temperature": config["temperature"],
                "top_k": config["top_k"],
                "top_p": config["top_p"],
            },
        )
        # Keep the response object until its text has been read, so a
        # failure while reading it is still handled below.
        completion = response
        text = response.text
        if text is not None:
            completion = text
    except Exception as e:
        # Log the exception in place of a completion, then surface it in
        # the UI. gr.Error takes a message string, so convert explicitly
        # and chain the original exception for the traceback.
        completion = e
        raise gr.Error(str(e)) from e
    finally:
        wandb_log(prompt, completion)
    return completion
# HTML blurb rendered under the title of the Gradio page.
description = """<a href='https://www.gradio.app/'>Gradio</a> UI using the <a href='https://cloud.google.com/vertex-ai'>Google Vertex AI</a> API
with Gemini Pro model. Evaluation via <a href='https://wandb.ai/bstraehle'>Weights & Biases</a>."""

# Close any Gradio interfaces still running before starting this one.
gr.close_all()

# Single-textbox UI: the prompt goes to invoke(), whose return value is
# shown in the "Completion" box.
demo = gr.Interface(
    fn=invoke,
    inputs=[gr.Textbox(label="Prompt", lines=1)],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM",
    description=description,
)

demo.launch()