eaglelandsonce committed on
Commit 8fb3c01 · verified · 1 Parent(s): 1989cea

Update app.py

Files changed (1)
app.py +31 -16
app.py CHANGED
@@ -1,22 +1,37 @@
  import streamlit as st
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
-
- # Load model and tokenizer
- @st.cache_resource  # Cache the resources to avoid reloading on every run
- def load_model():
-     tokenizer = AutoTokenizer.from_pretrained("mistralai/Codestral-22B-v0.1")
-     model = AutoModelForCausalLM.from_pretrained("mistralai/Codestral-22B-v0.1")
-     return tokenizer, model
-
- tokenizer, model = load_model()
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
- st.title("Chat with Codestral-22B")
- st.write("Ask a question and get a response from the Codestral-22B model.")
-
- user_input = st.text_input("You: ", "Type your question here...")
-
- if st.button("Send"):
-     with st.spinner("Generating response..."):
-         response = generator(user_input, max_length=100, num_return_sequences=1)
-         st.write("Codestral-22B: " + response[0]["generated_text"])
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from mistral_common.protocol.instruct.messages import UserMessage
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+ from mistral_inference.model import Transformer
+ from mistral_inference.generate import generate
+ from transformers import AutoModelForCausalLM
+
+ def main():
+     st.title("Codestral Inference with Hugging Face")
+
+     mistral_models_path = st.text_input("Enter the path to your Codestral model", "path/to/mistral_models/Codestral-22B-v0.1")
+
+     user_input = st.text_area("Enter your instruction", "Explain Machine Learning to me in a nutshell.")
+     max_tokens = st.slider("Max Tokens", min_value=10, max_value=500, value=64)
+     temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7)
+
+     if st.button("Generate"):
+         with st.spinner("Generating response..."):
+             result = generate_response(user_input, mistral_models_path, max_tokens, temperature)
+             st.success("Response generated!")
+             st.text_area("Generated Response", result, height=200)
+
+ def generate_response(user_input, model_path, max_tokens, temperature):
+     tokenizer = MistralTokenizer.v3()
+     completion_request = ChatCompletionRequest(messages=[UserMessage(content=user_input)])
+     tokens = tokenizer.encode_chat_completion(completion_request).tokens
+
+     model = Transformer.from_folder(model_path)
+     out_tokens, _ = generate([tokens], model, max_tokens=max_tokens, temperature=temperature, eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
+
+     result = tokenizer.decode(out_tokens[0])
+     return result
+
+ if __name__ == "__main__":
+     main()
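
Note: the default value of the path input, "path/to/mistral_models/Codestral-22B-v0.1", is a placeholder; `Transformer.from_folder` expects a local directory containing the raw checkpoint files. A minimal sketch of fetching them with `huggingface_hub.snapshot_download` (the target directory and the `allow_patterns` file list follow the mistralai/Codestral-22B-v0.1 model card and are assumptions here, not part of this commit):

```python
from pathlib import Path
from huggingface_hub import snapshot_download

# Any local directory works, as long as the app's path input points at it.
mistral_models_path = Path.home() / "mistral_models" / "Codestral-22B-v0.1"
mistral_models_path.mkdir(parents=True, exist_ok=True)

# File names assumed from the mistralai/Codestral-22B-v0.1 model card.
snapshot_download(
    repo_id="mistralai/Codestral-22B-v0.1",
    allow_patterns=["params.json", "consolidated.safetensors", "tokenizer.model.v3"],
    local_dir=mistral_models_path,
)
```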
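One behavioral difference worth flagging: the new code calls `Transformer.from_folder(model_path)` inside the button handler, so the 22B checkpoint is reloaded on every click, whereas the old version cached the model with `@st.cache_resource`. A minimal sketch of restoring that caching on top of the new APIs (`load_model` is an illustrative name, not from this commit):

```python
import streamlit as st
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.model import Transformer

@st.cache_resource  # load the checkpoint once per process instead of on every click
def load_model(model_path: str):
    tokenizer = MistralTokenizer.v3()
    model = Transformer.from_folder(model_path)
    return tokenizer, model
```

With this in place, `generate_response` would take the cached `(tokenizer, model)` pair instead of `model_path`, and repeated "Generate" clicks would skip the expensive reload.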