File size: 2,286 Bytes
80d3afc
9622700
be95987
80d3afc
 
 
0c09b64
 
25dfad2
6919ace
85e8517
3c5b4aa
5c795c7
 
 
6919ace
ace8c2c
bf362b5
7de5168
0c09b64
 
8e569e2
0c09b64
 
94ff692
0c09b64
edf6a1c
7f25538
7a49355
336cc86
241921c
563878a
80d3afc
360f345
2790944
f45a806
cf215da
2b88ad0
edf6a1c
22ebce8
cf215da
22a1dd5
cf215da
 
 
 
 
5dfb35a
cf215da
a84efc1
 
22a1dd5
4cd17a7
cf215da
22a1dd5
cf215da
aa69ed9
80d3afc
5265e46
b739d45
3416cce
80d3afc
cf215da
 
c558024
f9dce98
3416cce
cf215da
80d3afc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import json, os, vertexai

from dotenv import load_dotenv, find_dotenv
# Load the .env file located by find_dotenv() into the process environment;
# the return value of load_dotenv is not needed.
_ = load_dotenv(dotenv_path=find_dotenv())

# Both settings are required: a missing variable aborts start-up with KeyError
# (CREDENTIALS is read first, then PROJECT).
credentials, project = os.environ["CREDENTIALS"], os.environ["PROJECT"]

# Model name plus sampling settings. Every entry except "model" is forwarded
# as the generation config on each request.
config = dict(
    max_output_tokens=1000,
    model="gemini-1.5-pro-preview-0514",
    temperature=0.1,
    top_k=40,
    top_p=1.0,
)

from google.oauth2 import service_account

# CREDENTIALS holds a service-account key as a JSON string. Parse it and build
# the Credentials object that vertexai.init receives below.
credentials = service_account.Credentials.from_service_account_info(
    json.loads(credentials)
)

if credentials.expired:
    # refresh() needs an HTTP transport request object; google-auth provides
    # one backed by `requests`. Imported locally because this is its only use.
    from google.auth.transport.requests import Request

    credentials.refresh(Request())

vertexai.init(
    location="us-central1",
    credentials=credentials,
    project=project,
)

from vertexai.preview.generative_models import GenerativeModel

# A single model client, created at start-up and reused by every request.
generation_model = GenerativeModel(config["model"])

def invoke(prompt):
    """Send *prompt* to the Gemini model and return the generated text.

    Args:
        prompt: Text entered in the Gradio "Prompt" textbox.

    Returns:
        The model's completion text (an empty string if the response carries
        no text), rendered by the Markdown output component.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only; unconditionally
            in this copy because of the bring-your-own-credentials guard; or
            if the Vertex AI call (or reading its text) fails.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt is required.")

    # Deliberate kill switch for the shared deployment: every request is
    # refused here. The code below only runs once this line is removed in a
    # clone configured with its own credentials.
    raise gr.Error("Please clone and bring your own credentials.")

    generation_config = {
        key: config[key]
        for key in ("max_output_tokens", "temperature", "top_k", "top_p")
    }

    try:
        response = generation_model.generate_content(
            prompt, generation_config=generation_config
        )
        # Read .text inside the try so any error it raises is reported the
        # same way as a failed API call.
        completion = response.text
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user as a
        # Gradio error while preserving the original cause.
        raise gr.Error(str(e)) from e

    return completion or ""

description = """<a href='https://www.gradio.app/'>Gradio</a> UI using the <a href='https://cloud.google.com/vertex-ai'>Google Vertex AI</a> SDK 
                 with Gemini 1.5 Pro model."""

# Close any Gradio interfaces already running in this process before
# launching a new one.
gr.close_all()

demo = gr.Interface(
    fn=invoke,
    inputs=[
        gr.Textbox(
            label="Prompt",
            value="If I dry one shirt in the sun, it takes 1 hour. How long do 3 shirts take?",
            lines=1,
        )
    ],
    # COMPLETION only pre-fills the output pane, so an unset variable should
    # leave the pane empty rather than crash start-up with KeyError.
    outputs=[gr.Markdown(label="Completion", value=os.environ.get("COMPLETION", ""))],
    description=description,
)

demo.launch()