ag-mach committed
Commit bbdafb0 · 1 Parent(s): caf45a8

Input field for model id

Files changed (1)
  1. app.py +6 -5
app.py CHANGED

@@ -14,6 +14,7 @@ import torch
 
 gpt_model = 'gpt-4-1106-preview'
 embedding_model = 'text-embedding-3-small'
+default_model_id = "bigcode/starcoder2-7b"
 
 def init():
     if "conversation" not in st.session_state:
@@ -21,10 +22,8 @@ def init():
     if "chat_history" not in st.session_state:
         st.session_state.chat_history = None
 
-def init_llm_pipeline():
-    if "llm" not in st.session_state:
-        model_id = "bigcode/starcoder2-7b"
-
+def init_llm_pipeline(model_id):
+    if "llm" not in st.session_state:
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
@@ -97,11 +96,13 @@ def main():
 
 
     with st.sidebar:
+        st.subheader("Model selector")
+        model_id = st.text_input("Modelname on HuggingFace", default_model_id)
         st.subheader("Code Upload")
         upload_docs=st.file_uploader("Dokumente hier hochladen", accept_multiple_files=True)
         if st.button("Hochladen"):
             with st.spinner("Analysiere Dokumente ..."):
-                init_llm_pipeline()
+                init_llm_pipeline(model_id)
                 raw_text = get_text(upload_docs)
                 vectorstore = get_vectorstore(raw_text)
                 st.session_state.conversation = get_conversation(vectorstore)
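
For context, a minimal runnable sketch of the pattern this commit introduces: the Hugging Face model id is read from a sidebar text input (defaulting to bigcode/starcoder2-7b) and passed to init_llm_pipeline, which loads the tokenizer and model only once per session. The diff does not show what init_llm_pipeline stores in st.session_state, its extra from_pretrained arguments, or the RAG helpers (get_text, get_vectorstore, get_conversation), so the tuple storage and the trimmed loading call below are assumptions for illustration, not the committed implementation.

# Sketch only (assumptions noted above), not the actual app.py.
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

default_model_id = "bigcode/starcoder2-7b"

def init_llm_pipeline(model_id):
    # Load the model a single time per Streamlit session.
    if "llm" not in st.session_state:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)  # the real code passes further kwargs (cut off in the diff)
        st.session_state.llm = (tokenizer, model)  # assumed storage; not shown in the diff

def main():
    with st.sidebar:
        st.subheader("Model selector")
        # Any causal-LM repo id on the Hugging Face Hub can be entered here.
        model_id = st.text_input("Modelname on HuggingFace", default_model_id)
        if st.button("Hochladen"):
            with st.spinner("Analysiere Dokumente ..."):
                init_llm_pipeline(model_id)

if __name__ == "__main__":
    main()

Run it with `streamlit run app.py`. Because of the "llm" guard in st.session_state, changing the model id after the first load has no effect until the session (or that entry) is reset, which mirrors the guard in the committed code.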