import streamlit as st
from transformers import AutoTokenizer, LlamaForCausalLM
import torch


# Cache the model and tokenizer so they are loaded only once per session.
# st.cache_resource is the current Streamlit API for caching unhashable
# resources such as models (st.cache with allow_output_mutation is deprecated).
@st.cache_resource
def load_model():
    model = LlamaForCausalLM.from_pretrained('/Weights')
    # AutoTokenizer.from_pretrained expects the directory holding the tokenizer
    # files (tokenizer.model, tokenizer_config.json, ...), not the file itself.
    tokenizer = AutoTokenizer.from_pretrained('/code/Tokenizer')
    return model, tokenizer

model, tokenizer = load_model()

st.title("Chat with LlamaForCausalLM Model")

user_input = st.text_input("Type your message", "")

if st.button("Send"):
    inputs = tokenizer(user_input, return_tensors="pt")
    # Generate without tracking gradients; max_length caps the total number
    # of tokens (prompt plus response).
    with torch.no_grad():
        generate_ids = model.generate(inputs.input_ids, max_length=30)
    response = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    st.write(response)
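
# A minimal way to launch the app locally (assuming this file is saved as
# app.py and the weight/tokenizer paths above exist on the machine):
#
#     streamlit run app.py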