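# Gradio demo: send a prompt to the gemini-pro model on Google Vertex AI and
# track the run configuration with Weights & Biases.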
import gradio as gr
import json, os, vertexai, wandb
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
credentials = os.environ["CREDENTIALS"]
project = os.environ["PROJECT"]
wandb_api_key = os.environ["WANDB_API_KEY"]
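# Generation parameters; passed to wandb.init below and read back via wandb.config.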
config = {
    "max_output_tokens": 800,
    #"model": "text-bison@001",
    "model": "gemini-pro",
    "temperature": 0.1,
    "top_k": 40,
    "top_p": 1.0,
}
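# Log in to Weights & Biases and start a run that records the generation parameters.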
wandb.login(key = wandb_api_key)
wandb.init(project = "vertex-ai-llm", config = config)
config = wandb.config
credentials = json.loads(credentials)
from google.oauth2 import service_account
from google.auth.transport.requests import Request
credentials = service_account.Credentials.from_service_account_info(credentials)
if credentials.expired:
    credentials.refresh(Request())
vertexai.init(project = project,
              location = "us-central1",
              credentials = credentials
             )
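# Instantiate the Gemini Pro model; the earlier PaLM text-bison path is kept commented out for reference.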
#from vertexai.language_models import TextGenerationModel
#generation_model = TextGenerationModel.from_pretrained(config.model)
from vertexai.preview.generative_models import GenerativeModel
generation_model = GenerativeModel(config.model)
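# Gradio callback: the live model call and W&B logging are commented out in this
# public Space, so the function only returns a notice pointing to the source code.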
def invoke(prompt):
    #completion = generation_model.predict(prompt = prompt,
    #                                       max_output_tokens = config.max_output_tokens,
    #                                       temperature = config.temperature,
    #                                       top_k = config.top_k,
    #                                       top_p = config.top_p,
    #                                      ).text
    #completion = generation_model.generate_content(prompt, generation_config = {
    #                                                   "max_output_tokens": config.max_output_tokens,
    #                                                   "temperature": config.temperature,
    #                                                   "top_k": config.top_k,
    #                                                   "top_p": config.top_p,
    #                                                }).text
    #wandb.log({"prompt": prompt, "completion": completion})
    #return completion
    return "🛑 Execution is commented out; to view the source code, see https://huggingface.co/spaces/bstraehle/google-vertex-ai-llm/tree/main."
description = """<a href='https://www.gradio.app/'>Gradio</a> UI using the <a href='https://cloud.google.com/vertex-ai?hl=en/'>Google Vertex AI</a> API
with the gemini-pro foundation model. Model performance is evaluated via <a href='https://wandb.ai/bstraehle'>Weights & Biases</a>."""
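# Build and launch the Gradio UI.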
gr.close_all()
demo = gr.Interface(fn = invoke,
                    inputs = [gr.Textbox(label = "Prompt", lines = 1)],
                    outputs = [gr.Textbox(label = "Completion", lines = 1)],
                    title = "Generative AI - LLM",
                    description = description)
demo.launch()