eaglelandsonce committed
Commit a0ba2d3 · verified · 1 Parent(s): 389f6b6

Update app.py

Files changed (1): app.py (+29 -9)
app.py CHANGED
@@ -1,14 +1,34 @@
  # Import necessary libraries
  import streamlit as st
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoModelForCausalLM
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from mistral_common.protocol.instruct.messages import UserMessage
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+ import torch

- # Load the model and tokenizer
- model_name = "mistralai/Codestral-22B-v0.1"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name)
+ # Path to the mistral models
+ mistral_models_path = "MISTRAL_MODELS_PATH"

- # Initialize the pipeline
- text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
+ # Load the tokenizer
+ tokenizer = MistralTokenizer.v3()
+
+ # Load the model
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Codestral-22B-v0.1")
+ model.to("cuda")
+
+ # Function to generate text
+ def generate_text(prompt):
+     # Encode the prompt
+     completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
+     tokens = tokenizer.encode_chat_completion(completion_request).tokens
+
+     # Generate text using the model
+     with torch.no_grad():
+         generated_ids = model.generate(torch.tensor([tokens]).to(model.device), max_new_tokens=1000, do_sample=True)
+
+     # Decode the generated text
+     result = tokenizer.decode(generated_ids[0].tolist())
+     return result

  # Streamlit interface
  st.title("Codestral Text Generation")
@@ -25,9 +45,9 @@ if st.button("Generate"):
      if user_input:
          with st.spinner("Generating text..."):
              # Generate text using the model
-             generated_text = text_generator(user_input, max_length=100, num_return_sequences=1)
+             generated_text = generate_text(user_input)
              st.write("### Generated Text")
-             st.write(generated_text[0]['generated_text'])
+             st.write(generated_text)
      else:
          st.warning("Please enter a prompt to generate text.")
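Two observations on the model-loading lines in this change. First, `mistral_models_path = "MISTRAL_MODELS_PATH"` assigns a placeholder string that is never read afterwards; presumably it was meant to point `from_pretrained` at a local download. Second, Streamlit re-executes the whole script on every widget interaction, so a bare module-level `from_pretrained` call reloads the 22B checkpoint on each rerun. A minimal sketch of a cached variant, assuming a Streamlit version that provides `st.cache_resource` (1.18+); `load_model` is a hypothetical helper, not part of the commit:

import streamlit as st
from transformers import AutoModelForCausalLM

# Hypothetical helper (not in the commit): cache the loaded model across
# Streamlit reruns so the checkpoint is read only once per process.
@st.cache_resource
def load_model():
    model = AutoModelForCausalLM.from_pretrained("mistralai/Codestral-22B-v0.1")
    model.to("cuda")  # same GPU assumption the committed code makes
    return model

model = load_model()

The tokenizer could be cached the same way, though `MistralTokenizer.v3()` is cheap by comparison.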
 
 
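Relatedly, `model.to("cuda")` raises at startup on CPU-only hardware, so the app would crash before rendering anything on a Space without a GPU. A hedged sketch of a device fallback, assuming CPU execution is acceptable at least for smoke-testing (the `device` variable is introduced here, not in the commit):

import torch
from transformers import AutoModelForCausalLM

# Pick CUDA when available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Codestral-22B-v0.1",
    # bfloat16 roughly halves GPU memory; float32 is the safe CPU default.
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
model.to(device)

The rest of the script can stay device-agnostic, since `generate_text` already moves its input with `.to(model.device)`.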
 
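One behavioral detail of the new `generate_text`: `model.generate` returns the prompt tokens followed by the completion, so `tokenizer.decode(generated_ids[0].tolist())` echoes the user's prompt back at the start of the output. A sketch of a variant that decodes only the newly generated tokens; it reuses the `tokenizer` and `model` globals from the diff, and `generate_completion_only` is a hypothetical name:

import torch
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

def generate_completion_only(prompt):
    # Encode the prompt exactly as the committed generate_text does
    # (tokenizer and model are the globals defined in app.py).
    request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(request).tokens

    with torch.no_grad():
        generated_ids = model.generate(
            torch.tensor([tokens]).to(model.device),
            max_new_tokens=1000,
            do_sample=True,
        )

    # Everything past the prompt length is newly generated text.
    new_tokens = generated_ids[0].tolist()[len(tokens):]
    return tokenizer.decode(new_tokens)

Swapping this in for `generate_text` would leave the Streamlit wiring unchanged: `st.write(generate_completion_only(user_input))`.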