from PIL import Image
import io

import streamlit as st
import google.generativeai as genai
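
# Turn off all of Gemini's built-in safety filters; blocked requests are
# reported to the user manually further below.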
safety_settings = [
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE"
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE"
    },
]
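
# Minimal password gate: the app only renders after the password stored in
# Streamlit secrets (key "real_password") has been entered correctly.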
password_placeholder = st.empty()
password = password_placeholder.text_input("Passcode", type="password")
if password == st.secrets["real_password"]:
    # Hide the password prompt once the correct password has been entered.
    password_placeholder.empty()

    with st.sidebar:
        st.title("Gemini Pro")
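
    # Generation parameters shared by every Gemini request (passed to
    # GenerativeModel in get_response below).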
    CONFIG = {
        "temperature": 0.5,
        "top_p": 1,
        "top_k": 32,
        "max_output_tokens": 4096,
    }

    genai.configure(api_key=st.secrets["api_key"])
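
    # Optional image upload: with an image attached, questions go to the vision
    # model; without one, this is a plain text chat.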
    uploaded_image = st.file_uploader(
        label="Upload an image",
        label_visibility="visible",
        help="If an image is uploaded you can ask questions about it; otherwise this is a regular chat.",
        accept_multiple_files=False,
        type=["png", "jpg"],
    )

    image_bytes = None
    if uploaded_image:
        image_bytes = uploaded_image.read()
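
    # Helper for both models: builds a client with the shared CONFIG and returns
    # a streamed response, or an error string if the request fails.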
    def get_response(messages, model_name="gemini-pro"):
        try:
            model = genai.GenerativeModel(model_name, generation_config=CONFIG)
            return model.generate_content(messages, stream=True, safety_settings=safety_settings)
        except Exception:
            return "Sorry, the request did not pass the content filter."
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    messages = st.session_state["messages"]
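
    # Replay the stored conversation on every rerun.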
    if messages:
        for item in messages:
            role, parts = item["role"], item["parts"]
            if role == "user":
                st.chat_message("user").markdown(parts[0])
            elif role == "model":
                st.chat_message("assistant").markdown(parts[0])

    chat_message = st.chat_input("Ask me something!")

    if chat_message:
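        # Show the user's message, then stream the model's reply into a placeholder.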
        st.chat_message("user").markdown(chat_message)
        res_area = st.chat_message("assistant").empty()

        # Record the user turn so it is replayed on later reruns.
        messages.append({"role": "user", "parts": [chat_message]})

        if image_bytes:
            # Vision mode: send the prompt together with the uploaded image.
            vision_message = [chat_message, Image.open(io.BytesIO(image_bytes))]
            res = get_response(vision_message, model_name="gemini-pro-vision")
        else:
            # Text mode: send the accumulated history to the chat model.
            res = get_response(messages)
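
        # Stream the reply; if the request or response was blocked, show why.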
        res_text = ""
        if isinstance(res, str):
            # get_response already returned an error message.
            res_text = res
            res_area.markdown(res_text)
        else:
            try:
                for chunk in res:
                    res_text += chunk.text
                    res_area.markdown(res_text)
            except Exception:
                res_text += f"The request did not pass the content filter:\n{res.prompt_feedback}"
                res_area.markdown(res_text)

        # Store the reply (or the error text) in the history.
        messages.append({"role": "model", "parts": [res_text]})
else:
    st.warning("Wrong password, sorry...")