import warnings
from pathlib import Path

import toml
import google.generativeai as palm_api

from pingpong import PingPong
from pingpong.pingpong import PPManager
from pingpong.pingpong import PromptFmt
from pingpong.pingpong import UIFmt
from pingpong.gradio import GradioChatUIFmt

from .utils import set_palm_api_key

# Set PaLM API Key
set_palm_api_key()
# Load PaLM Prompt Templates
palm_prompts = toml.load(Path('.') / 'assets' / 'palm_prompts.toml')

class PaLMChatPromptFmt(PromptFmt):
    @classmethod
    def ctx(cls, context):
        warnings.warn("The 'palmchat' module is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        # PaLM chat takes the context as a separate top-level field,
        # so there is nothing to format here.

    @classmethod
    def prompt(cls, pingpong, truncate_size):
        warnings.warn("The 'palmchat' module is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        ping = pingpong.ping[:truncate_size]
        pong = pingpong.pong

        # A pending exchange (no pong yet) contributes only the USER message.
        if pong is None or pong.strip() == "":
            return [
                {
                    "author": "USER",
                    "content": ping
                },
            ]
        else:
            pong = pong[:truncate_size]
            return [
                {
                    "author": "USER",
                    "content": ping
                },
                {
                    "author": "AI",
                    "content": pong
                },
            ]
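
# Illustrative sketch (not part of the original module): PaLMChatPromptFmt.prompt
# maps one PingPong exchange onto the author/content message dicts that the
# PaLM chat endpoint expects. For example:
#
#   PaLMChatPromptFmt.prompt(PingPong("hi", "hello!"), truncate_size=None)
#   # -> [{"author": "USER", "content": "hi"},
#   #     {"author": "AI", "content": "hello!"}]
#
# Slicing with truncate_size=None keeps the full string, so None means
# "no truncation".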

class PaLMChatPPManager(PPManager):
    def build_prompts(self, from_idx: int=0, to_idx: int=-1, fmt: PromptFmt=PaLMChatPromptFmt, truncate_size: int=None):
        warnings.warn("The 'palmchat' module is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        # -1 (or an out-of-range index) means "up to the latest exchange".
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results += fmt.prompt(pingpong, truncate_size=truncate_size)

        return results
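
# Illustrative sketch (assuming self.pingpongs holds PingPong objects, as in
# the pingpong library): build_prompts flattens the selected slice of history
# into a single message list ready to pass to palm_api.chat_async.
# Hypothetical usage:
#
#   ppm = PaLMChatPPManager()
#   ppm.add_pingpong(PingPong("hi", "hello!"))
#   ppm.build_prompts()
#   # -> [{"author": "USER", "content": "hi"},
#   #     {"author": "AI", "content": "hello!"}]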

class GradioPaLMChatPPManager(PaLMChatPPManager):
    def build_uis(self, from_idx: int=0, to_idx: int=-1, fmt: UIFmt=GradioChatUIFmt):
        warnings.warn("The 'palmchat' module is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
        if to_idx == -1 or to_idx >= len(self.pingpongs):
            to_idx = len(self.pingpongs)

        results = []
        for pingpong in self.pingpongs[from_idx:to_idx]:
            results.append(fmt.ui(pingpong))

        return results
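
# Illustrative sketch: build_uis renders the same history for a Gradio Chatbot
# instead of the PaLM API; assuming GradioChatUIFmt follows the pingpong
# library's convention, each exchange becomes a (ping, pong) tuple:
#
#   uis = GradioPaLMChatPPManager()
#   uis.add_pingpong(PingPong("hi", "hello!"))
#   uis.build_uis()
#   # -> [("hi", "hello!")]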

async def gen_text(
    prompt,
    mode="chat",  # "chat" or "text"
    parameters=None,
    use_filter=True
):
    warnings.warn("The 'palmchat' module is deprecated and will not be supported in future versions.", DeprecationWarning, stacklevel=2)
    if parameters is None:
        temperature = 1.0
        top_k = 40
        top_p = 0.95
        max_output_tokens = 1024

        # Default safety settings; lower thresholds block more content.
        safety_settings = [
            {"category": "HARM_CATEGORY_DEROGATORY", "threshold": 1},
            {"category": "HARM_CATEGORY_TOXICITY", "threshold": 1},
            {"category": "HARM_CATEGORY_VIOLENCE", "threshold": 2},
            {"category": "HARM_CATEGORY_SEXUAL", "threshold": 2},
            {"category": "HARM_CATEGORY_MEDICAL", "threshold": 2},
            {"category": "HARM_CATEGORY_DANGEROUS", "threshold": 2},
        ]

        # Relax every category to the most permissive threshold when
        # filtering is turned off.
        if not use_filter:
            for idx, _ in enumerate(safety_settings):
                safety_settings[idx]['threshold'] = 4

        if mode == "chat":
            parameters = {
                'model': 'models/chat-bison-001',
                'candidate_count': 1,
                'context': "",
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'safety_settings': safety_settings,
            }
        else:
            parameters = {
                'model': 'models/text-bison-001',
                'candidate_count': 1,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'max_output_tokens': max_output_tokens,
                'safety_settings': safety_settings,
            }

    try:
        if mode == "chat":
            response = await palm_api.chat_async(**parameters, messages=prompt)
        else:
            response = palm_api.generate_text(**parameters, prompt=prompt)
    except Exception as e:
        raise EnvironmentError("PaLM API is not available.") from e

    if use_filter and len(response.filters) > 0:
        raise Exception("PaLM API has withheld a response due to content safety concerns.")

    response_txt = response.last if mode == "chat" else response.result
    return response, response_txt
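
# Minimal usage sketch (not part of the original module). It assumes the PaLM
# API key was configured by set_palm_api_key() at import time and that the
# pingpong classes behave as sketched above; model names and defaults come
# from gen_text itself.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        ppm = PaLMChatPPManager()
        ppm.add_pingpong(PingPong("What is the capital of France?", None))
        messages = ppm.build_prompts()

        # "chat" mode routes to models/chat-bison-001 via palm_api.chat_async.
        _, text = await gen_text(messages, mode="chat")
        print(text)

    asyncio.run(_demo())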