gmnpr / app.py
getapi's picture
Update app.py
0589f7a
raw
history blame
3.31 kB
from PIL import Image
import io
import streamlit as st
import google.generativeai as genai
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
},
]
password_placeholder = st.empty()
password = password_placeholder.text_input("пасскод", type="password")
if password == st.secrets["real_password"]:
password_placeholder.empty()
# st.success("тебе позволено войти, хорошо")
with st.sidebar:
st.title("Gemini Pro")
CONFIG = {
"temperature": 0.5,
"top_p": 1,
"top_k": 32,
"max_output_tokens": 4096,
}
genai.configure(api_key=st.secrets["api_key"])
uploaded_image = st.file_uploader(
label="загрузи изображение",
label_visibility="visible",
help="если загружено изображение - можно спрашивать по нему что-то, если нет - будет обычный чат",
accept_multiple_files=False,
type=["png", "jpg"],
)
if uploaded_image:
image_bytes = uploaded_image.read()
def get_response(messages, model="gemini-pro"):
try:
model = genai.GenerativeModel(model, generation_config=genai.GenerationConfig(candidate_count=1, max_output_tokens=4096, temperature=0.6))
res = model.generate_content(messages, stream=True, safety_settings=safety_settings)
return res
except:
return "Извини, но запрос не прошел цензуру."
if "messages" not in st.session_state:
st.session_state["messages"] = []
messages = st.session_state["messages"]
if messages:
for item in messages:
role, parts = item.values()
if role == "user":
st.chat_message("user").markdown(parts[0])
elif role == "model":
st.chat_message("assistant").markdown(parts[0])
chat_message = st.chat_input("Спроси что-нибудь!")
if chat_message:
st.chat_message("user").markdown(chat_message)
res_area = st.chat_message("assistant").empty()
if "image_bytes" in globals():
vision_message = [chat_message, Image.open(io.BytesIO(image_bytes))]
res = get_response(vision_message, model="gemini-pro-vision")
else:
vision_message = [{"role": "user", "parts": [chat_message]}]
res = get_response(vision_message)
res_text = ""
try:
for chunk in res:
res_text += chunk.text
res_area.markdown(res_text)
except:
res_text += f"запрос не прошел цензуру:\n{str(res.prompt_feedback)}"
res_area.markdown(res_text)
messages.append({"role": "model", "parts": [res_text]})
else:
st.warning("неправильный пароль, увы...")