File size: 2,221 Bytes
5c14a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from PIL import Image
import io
import os
import streamlit as st
import google.generativeai as genai

# import google.ai.generativelanguage as glm


# Disable every built-in Gemini content filter so the safety layer never
# blocks a response. Same four categories, same threshold, built once.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Sidebar: title, API-key configuration, and an optional image upload.
with st.sidebar:
    st.title("Gemini Pro")

    # Authenticate the Gemini client from the Streamlit secrets store.
    genai.configure(api_key=st.secrets["api_key"])

    uploaded_file = st.file_uploader(
        "upload image",
        label_visibility="collapsed",
        accept_multiple_files=False,
        type=["png", "jpg"],
    )

    # NOTE: `image_bytes` is only bound when an image was actually
    # uploaded; the chat handler below probes `globals()` for this name
    # to decide between the text and vision models.
    if uploaded_file:
        image_bytes = uploaded_file.read()


def get_response(messages, model="gemini-pro"):
    """Stream a Gemini completion for *messages*.

    Args:
        messages: Content accepted by ``GenerativeModel.generate_content`` —
            either a list of ``{"role": ..., "parts": [...]}`` dicts (text
            chat) or a flat ``[prompt, PIL.Image]`` list (vision).
        model: Gemini model name to instantiate.

    Returns:
        An iterable of streamed response chunks, each exposing ``.text``.
    """
    # Bind the client to a fresh name; the original rebound the `model`
    # parameter, shadowing the model-name string with the client object.
    client = genai.GenerativeModel(model)
    return client.generate_content(
        messages, stream=True, safety_settings=safety_settings
    )


# Persist the chat transcript across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
messages = st.session_state["messages"]

# Replay the stored transcript. Use explicit key access instead of
# `role, parts = item.values()`, which silently depends on dict insertion
# order and raises if an entry ever gains an extra key.
for item in messages:
    role = item["role"]
    text = item["parts"][0]
    if role == "user":
        st.chat_message("user").markdown(text)
    elif role == "model":
        st.chat_message("assistant").markdown(text)

chat_message = st.chat_input("Спроси что-нибудь!")

if chat_message:
    st.chat_message("user").markdown(chat_message)
    res_area = st.chat_message("assistant").empty()

    # BUG FIX: the user turn was never stored in the transcript, so after
    # a rerun only model replies were replayed, and the text model was
    # handed a fresh single-message list each turn instead of the history.
    messages.append({"role": "user", "parts": [chat_message]})

    if "image_bytes" in globals():
        # Vision model takes a flat [prompt, image] list, not chat history.
        vision_message = [chat_message, Image.open(io.BytesIO(image_bytes))]
        res = get_response(vision_message, model="gemini-pro-vision")
    else:
        # Text model: send the whole transcript for multi-turn context.
        res = get_response(messages)

    # Stream chunks into the placeholder as they arrive.
    res_text = ""
    for chunk in res:
        res_text += chunk.text
        res_area.markdown(res_text)

    messages.append({"role": "model", "parts": [res_text]})