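# Streamlit app: text generation with Mistral AI's Codestral-22B-v0.1.
# Assumed dependencies: pip install streamlit torch transformers mistral-common
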
import streamlit as st
import torch
from transformers import AutoModelForCausalLM
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

# Codestral's chat template is handled here by mistral_common's v3 instruct
# tokenizer rather than a Hugging Face AutoTokenizer.
tokenizer = MistralTokenizer.v3() |

# Load the model once and cache it: Streamlit reruns this script on every
# interaction, and reloading the 22B weights each time would stall the app.
@st.cache_resource
def load_model():
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Codestral-22B-v0.1",
        torch_dtype=torch.bfloat16,  # half precision; fp32 weights would need ~88 GB
    )
    model.to("cuda")
    return model

model = load_model()
|
def generate_text(prompt):
    # Wrap the prompt in an instruct-style chat request and tokenize it
    # with the Mistral chat template.
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        generated_ids = model.generate(
            torch.tensor([tokens]).to(model.device),
            max_new_tokens=1000,
            do_sample=True,
        )

    # generate() returns the prompt tokens followed by the completion;
    # decode only the newly generated part.
    result = tokenizer.decode(generated_ids[0][len(tokens):].tolist())
    return result
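
# Generation knobs above are illustrative: max_new_tokens caps the completion
# length, and do_sample=True draws from the model's distribution, whereas
# do_sample=False would decode greedily and make the output deterministic.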

st.title("Codestral Text Generation")

st.write("""
This is a text generation application using the Codestral model from Mistral AI.
Enter your prompt below and generate text.
""")

user_input = st.text_area("Enter your prompt here:", "")

if st.button("Generate"):
    if user_input:
        with st.spinner("Generating text..."):
            generated_text = generate_text(user_input)
        st.write("### Generated Text")
        st.write(generated_text)
    else:
        st.warning("Please enter a prompt to generate text.")

# Note: Streamlit has no st.run() API, so no __main__ guard is needed.
# Launch the app from the command line with: streamlit run <this_script>.py