import gradio as gr
import torch
import threading
import spaces

from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM

# Model identifier
# Note: AutoGPTQForCausalLM.from_quantized expects a GPTQ-quantized checkpoint; the
# "-ct2-int8" suffix here indicates a CTranslate2 int8 export, which auto_gptq cannot
# load, so substitute a GPTQ repo id if loading fails.
model_id = "jncraton/SmolLM2-1.7B-Instruct-ct2-int8"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)

print("Is CUDA available?", torch.cuda.is_available())

class ModelWrapper:
    def __init__(self):
        self.model = None  # Model will be loaded when GPU is allocated

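    # @spaces.GPU asks Hugging Face ZeroGPU to attach a GPU only for the duration of
    # this call, which is why the model is loaded lazily inside generate() rather than
    # at import time.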
    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Explicitly set device_map to 'cuda'
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )
            self.model.eval()

        print("Model is on device:", next(self.model.parameters()).device)

        # Tokenize the input prompt
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Set up the streamer
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
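        # Note: by default the streamer also echoes the prompt tokens; passing
        # skip_prompt=True would stream only the newly generated text.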

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the accumulated text as tokens arrive, so the UI updates in real time
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        # The streamer is exhausted only once generation finishes; join the worker thread to clean up
        thread.join()

# Instantiate the model wrapper
model_wrapper = ModelWrapper()
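
# A minimal sketch, not wired into the interface: because the checkpoint is an *Instruct*
# model, formatting the raw prompt with the tokenizer's chat template usually yields better
# completions than feeding it verbatim. `build_chat_prompt` is a hypothetical helper name.
def build_chat_prompt(user_prompt):
    # apply_chat_template renders the messages with the model's built-in chat template
    messages = [{"role": "user", "content": user_prompt}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )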

# Create the Gradio interface
interface = gr.Interface(
    fn=model_wrapper.generate,
    inputs=gr.Textbox(lines=5, label="Input Prompt"),
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Mistral-Large-Instruct-2407 Text Completion",
    description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
    allow_flagging='never',
    live=False,
    cache_examples=False
)
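
# Streaming generator outputs run through Gradio's request queue; older Gradio 3.x releases
# required enabling it explicitly (e.g. interface.queue()) before launch, while recent
# releases enable queueing by default.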

if __name__ == "__main__":
    interface.launch()