Robin246 committed
Commit 756023e · 1 Parent(s): ca928eb

Upload app.py

Files changed (1)
  1. app.py +23 -18
app.py CHANGED
@@ -1,28 +1,33 @@
 import streamlit as st
 from transformers import AutoModelForSeq2SeqLM, T5Tokenizer
 
-def generate_response(input_prompt):
-    fine_tuned_model_path = "Robin246/inxai" # Update with your Hugging Face model path
-
-    # Load the fine-tuned model and tokenizer
-    fine_tuned_model = AutoModelForSeq2SeqLM.from_pretrained(fine_tuned_model_path)
-    tokenizer = T5Tokenizer.from_pretrained(fine_tuned_model_path)
+def generate_response(input_prompt, model_path):
+    if model_path == 'google/flan-t5-small':
+        model_name = 'Google Flan T5'
+    elif model_path == 'MBZUAI/LaMini-Flan-T5-77M':
+        model_name = 'Lamini Flan T5'
+    else:
+        model_name = 'INXAI'
+
+    fine_tuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
+    if model_path == 'MBZUAI/LaMini-Flan-T5-77M':
+        tokenizer = T5Tokenizer.from_pretrained('t5-base')
+    else:
+        tokenizer = T5Tokenizer.from_pretrained(model_path)
 
     input_text = f"Input prompt: {input_prompt}"
-
     input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=64, padding="max_length", truncation=True)
-    output_ids = fine_tuned_model.generate(input_ids, max_length=256, num_return_sequences=1, num_beams=1, early_stopping=True)
+    output_ids = fine_tuned_model.generate(input_ids, max_length=256, num_return_sequences=1, num_beams=2, early_stopping=True)
     generated_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-    return generated_output
+    return generated_output, model_name
 
 def main():
-    st.title('INXAI LLM model')
-    user_input = st.text_input('Input Prompt:', '')
-
-    if st.button('Generate'):
-        response = generate_response(user_input)
-        st.write('Generated Reply:', response)
-
-if __name__ == '__main__':
+    st.title("INXAI LLM Model")
+    model_selection = st.selectbox("Choose a model", ["google/flan-t5-small", "MBZUAI/LaMini-Flan-T5-77M", "Robin246/inxai"])
+    input_prompt = st.text_input("Enter input text")
+    if st.button("Generate"):
+        reply, model_name = generate_response(input_prompt, model_selection)
+        st.write(f"Generated Reply : {reply}")
+
+if __name__ == "__main__":
    main()
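
A side note on the updated generate_response: both from_pretrained calls run inside the function, so every click of the Generate button reloads the selected checkpoint from disk (or re-downloads it on a cold cache). A minimal sketch of how the loading step could be memoized with Streamlit's st.cache_resource (available in recent Streamlit versions; the helper name load_model_and_tokenizer is illustrative and not part of this commit):

import streamlit as st
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer

# Hypothetical helper, not in this commit: cache one (model, tokenizer)
# pair per model_path so repeated Generate clicks reuse the loaded weights.
@st.cache_resource
def load_model_and_tokenizer(model_path):
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    # Keep the commit's fallback: the LaMini checkpoint is paired with
    # the stock 't5-base' tokenizer instead of its own.
    tokenizer_path = 't5-base' if model_path == 'MBZUAI/LaMini-Flan-T5-77M' else model_path
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
    return model, tokenizer

generate_response would then open with fine_tuned_model, tokenizer = load_model_and_tokenizer(model_path) in place of its current loading block, and the app runs as before with streamlit run app.py.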
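
Also worth flagging: the new main() unpacks model_name from generate_response but never displays it. If the intent is to label each reply with the model that produced it, a one-line tweak (again illustrative, not in the commit) would be:

        st.write(f"Generated Reply ({model_name}): {reply}")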