# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
import torch
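# Requires: streamlit, transformers, mistral_common, and torch, plus a CUDA GPU
# with enough memory to hold Codestral-22B (see the note below the model load).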
# Path to local Mistral model files (placeholder; unused below, since the model
# is fetched from the Hugging Face Hub by name)
mistral_models_path = "MISTRAL_MODELS_PATH"
# Load the tokenizer and model once per server process. Without st.cache_resource,
# Streamlit would reload the 22B model on every rerun of the script.
@st.cache_resource
def load_model_and_tokenizer():
    tokenizer = MistralTokenizer.v3()
    model = AutoModelForCausalLM.from_pretrained("mistralai/Codestral-22B-v0.1")
    model.to("cuda")
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
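# Note: in full float32 precision the 22B parameters alone take roughly 88 GB.
# On a single GPU you would typically pass torch_dtype=torch.float16 (halving
# the footprint) or device_map="auto" to from_pretrained; adjust to your hardware.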
# Function to generate text from a user prompt
def generate_text(prompt):
    # Encode the prompt using the Mistral chat-completion format
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens
    # Generate text using the model
    with torch.no_grad():
        generated_ids = model.generate(
            torch.tensor([tokens]).to(model.device),
            max_new_tokens=1000,
            do_sample=True,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt
    result = tokenizer.decode(generated_ids[0][len(tokens):].tolist())
    return result
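# Sampling (do_sample=True) gives varied output on each click; for deterministic
# results one could pass do_sample=False for greedy decoding, or fix a seed with
# torch.manual_seed before generating.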
# Streamlit interface
st.title("Codestral Text Generation")
st.write("""
This is a text generation application using the Codestral model from Mistral AI.
Enter your prompt below and generate text.
""")
# User input
user_input = st.text_area("Enter your prompt here:", "")
if st.button("Generate"):
    if user_input:
        with st.spinner("Generating text..."):
            # Generate text using the model
            generated_text = generate_text(user_input)
        st.write("### Generated Text")
        st.write(generated_text)
    else:
        st.warning("Please enter a prompt to generate text.")
# Run the app from the command line with `streamlit run app.py`;
# Streamlit has no st.run(), so no explicit main entry point is needed.