"""Gradio UI for prompting a Google Vertex AI model, tracked with Weights & Biases.

Required environment variables (may also come from a ``.env`` file found by
``find_dotenv``):
    CREDENTIALS    -- Google service-account key as a JSON string
    PROJECT        -- Google Cloud project id passed to ``vertexai.init``
    WANDB_API_KEY  -- Weights & Biases API key
"""

import json
import os

import gradio as gr
import vertexai
import wandb
from dotenv import find_dotenv, load_dotenv
from google.auth.transport.requests import Request
from google.oauth2 import service_account
from vertexai.preview.generative_models import GenerativeModel

# Populate os.environ from the nearest .env file, if there is one.
load_dotenv(find_dotenv())

# Missing variables fail fast here with a KeyError naming the variable.
credentials_json = os.environ["CREDENTIALS"]
project = os.environ["PROJECT"]
wandb_api_key = os.environ["WANDB_API_KEY"]

config = {
    "max_output_tokens": 800,
    "model": "gemini-pro",
    "temperature": 0.1,
    "top_k": 40,
    "top_p": 1.0,
}

wandb.login(key=wandb_api_key)
wandb.init(project="vertex-ai-llm", config=config)
# From here on, read settings through wandb.config (attribute access such as
# config.model); presumably so the run's recorded config is the source of
# truth -- confirm if W&B sweeps are meant to override these values.
config = wandb.config

credentials = service_account.Credentials.from_service_account_info(
    json.loads(credentials_json)
)
if credentials.expired:
    # Request must be imported explicitly; without the import this line
    # raised NameError instead of refreshing the token.
    credentials.refresh(Request())

vertexai.init(
    project=project,
    location="us-central1",
    credentials=credentials,
)

generation_model = GenerativeModel(config.model)


def invoke(prompt):
    """Return the completion shown in the UI for *prompt*.

    Model execution is deliberately disabled in this deployment: *prompt* is
    ignored and a fixed notice pointing at the source code is returned, so no
    Vertex AI or W&B logging calls are made per request.

    To re-enable generation, replace the ``return`` below with::

        completion = generation_model.generate_content(
            prompt,
            generation_config={
                "max_output_tokens": config.max_output_tokens,
                "temperature": config.temperature,
                "top_k": config.top_k,
                "top_p": config.top_p,
            },
        ).text
        wandb.log({"prompt": prompt, "completion": completion})
        return completion
    """
    return "🛑 Execution is commented out, to view the source code see https://huggingface.co/spaces/bstraehle/google-vertex-ai-llm/tree/main."


# Kept as explicit concatenation so the exact text (including the space
# before the line break) is unambiguous.
description = (
    "Gradio UI using Google Vertex AI API with gemini-pro foundation model. \n"
    "Model performance evaluation via Weights & Biases."
)

# Close any Gradio apps still running from a previous execution in this process.
gr.close_all()

demo = gr.Interface(
    fn=invoke,
    inputs=[gr.Textbox(label="Prompt", lines=1)],
    outputs=[gr.Textbox(label="Completion", lines=1)],
    title="Generative AI - LLM",
    description=description,
)
demo.launch()