File size: 1,776 Bytes
389f6b6
5afaecc
a0ba2d3
 
 
 
 
444f78b
a0ba2d3
 
444f78b
a0ba2d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444f78b
 
389f6b6
444f78b
389f6b6
 
 
 
444f78b
b4ef92f
389f6b6
b4ef92f
389f6b6
 
 
 
a0ba2d3
389f6b6
a0ba2d3
389f6b6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Import necessary libraries
import streamlit as st
from transformers import AutoModelForCausalLM
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
import torch

# Path to the mistral models (placeholder — currently unused because the
# model is pulled from the Hugging Face Hub below; set a real path or remove).
mistral_models_path = "MISTRAL_MODELS_PATH"

# Load the v3 tokenizer that matches Codestral's chat template.
tokenizer = MistralTokenizer.v3()

# Load the model once at module import. Fall back to CPU when CUDA is not
# available so the app still starts on machines without a GPU (the original
# unconditional .to("cuda") raised on CPU-only hosts).
model = AutoModelForCausalLM.from_pretrained("mistralai/Codestral-22B-v0.1")
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Function to generate text
def generate_text(prompt):
    """Generate a chat completion for *prompt* and return only the new text.

    The prompt is wrapped in a UserMessage, tokenized with the Mistral v3
    chat template, fed to the model, and the model's continuation — without
    the echoed prompt tokens — is decoded back to a string.
    """
    # Encode the prompt via the chat-completion template.
    completion_request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    # Generate text using the model; no gradient tracking needed for inference.
    with torch.no_grad():
        generated_ids = model.generate(
            torch.tensor([tokens]).to(model.device),
            max_new_tokens=1000,
            do_sample=True,
        )

    # model.generate returns prompt + continuation in one sequence; drop the
    # input tokens so the UI shows only the newly generated text (the
    # original decoded the full sequence and echoed the prompt back).
    result = tokenizer.decode(generated_ids[0][len(tokens):].tolist())
    return result

# Streamlit interface: Streamlit re-executes this whole script on every
# widget interaction, so the UI is declared top-to-bottom as plain calls.
st.title("Codestral Text Generation")

st.write("""
This is a text generation application using the Codestral model from Mistral AI.
Enter your prompt below and generate text.
""")

# User input: free-form prompt; defaults to an empty string.
user_input = st.text_area("Enter your prompt here:", "")

# st.button returns True only on the rerun triggered by the click.
if st.button("Generate"):
    if user_input:
        # Spinner keeps the UI responsive-looking during the (slow) generate.
        with st.spinner("Generating text..."):
            # Generate text using the model
            generated_text = generate_text(user_input)
            st.write("### Generated Text")
            st.write(generated_text)
    else:
        # Empty prompt: warn instead of calling the model with no content.
        st.warning("Please enter a prompt to generate text.")

# NOTE(review): the original `if __name__ == '__main__': st.run()` was
# removed — Streamlit has no st.run() API, so it raised AttributeError
# (and `streamlit run` executes this script with __name__ == '__main__',
# so the broken branch was always taken). Streamlit runs the script
# top-to-bottom itself; launch the app with:
#   streamlit run <this_file>.py