Spaces:
Runtime error
Runtime error
inputfield for model id
Browse files
app.py
CHANGED
@@ -14,6 +14,7 @@ import torch
|
|
14 |
|
15 |
gpt_model = 'gpt-4-1106-preview'
|
16 |
embedding_model = 'text-embedding-3-small'
|
|
|
17 |
|
18 |
def init():
|
19 |
if "conversation" not in st.session_state:
|
@@ -21,10 +22,8 @@ def init():
|
|
21 |
if "chat_history" not in st.session_state:
|
22 |
st.session_state.chat_history = None
|
23 |
|
24 |
-
def init_llm_pipeline():
|
25 |
-
if "llm" not in st.session_state:
|
26 |
-
model_id = "bigcode/starcoder2-7b"
|
27 |
-
|
28 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
29 |
model = AutoModelForCausalLM.from_pretrained(
|
30 |
model_id,
|
@@ -97,11 +96,13 @@ def main():
|
|
97 |
|
98 |
|
99 |
with st.sidebar:
|
|
|
|
|
100 |
st.subheader("Code Upload")
|
101 |
upload_docs=st.file_uploader("Dokumente hier hochladen", accept_multiple_files=True)
|
102 |
if st.button("Hochladen"):
|
103 |
with st.spinner("Analysiere Dokumente ..."):
|
104 |
-
init_llm_pipeline()
|
105 |
raw_text = get_text(upload_docs)
|
106 |
vectorstore = get_vectorstore(raw_text)
|
107 |
st.session_state.conversation = get_conversation(vectorstore)
|
|
|
14 |
|
15 |
gpt_model = 'gpt-4-1106-preview'
|
16 |
embedding_model = 'text-embedding-3-small'
|
17 |
+
default_model_id = "bigcode/starcoder2-7b"
|
18 |
|
19 |
def init():
|
20 |
if "conversation" not in st.session_state:
|
|
|
22 |
if "chat_history" not in st.session_state:
|
23 |
st.session_state.chat_history = None
|
24 |
|
25 |
+
def init_llm_pipeline(model_id):
|
26 |
+
if "llm" not in st.session_state:
|
|
|
|
|
27 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
28 |
model = AutoModelForCausalLM.from_pretrained(
|
29 |
model_id,
|
|
|
96 |
|
97 |
|
98 |
with st.sidebar:
|
99 |
+
st.subheader("Model selector")
|
100 |
+
model_id = st.text_input("Modelname on HuggingFace", default_model_id)
|
101 |
st.subheader("Code Upload")
|
102 |
upload_docs=st.file_uploader("Dokumente hier hochladen", accept_multiple_files=True)
|
103 |
if st.button("Hochladen"):
|
104 |
with st.spinner("Analysiere Dokumente ..."):
|
105 |
+
init_llm_pipeline(model_id)
|
106 |
raw_text = get_text(upload_docs)
|
107 |
vectorstore = get_vectorstore(raw_text)
|
108 |
st.session_state.conversation = get_conversation(vectorstore)
|