import threading

import torch
import spaces
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM  # provides from_quantized() for GPTQ checkpoints

print("Is CUDA available?", torch.cuda.is_available())

# GPTQ-quantized checkpoint to serve; swap in any compatible model ID.
model_id = "TheBloke/Llama-2-7B-Chat-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)

class ModelWrapper:
    def __init__(self):
        # Defer loading: under ZeroGPU, CUDA is only available inside
        # @spaces.GPU-decorated calls, so the model is loaded on first use.
        self.model = None

    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Explicitly set device_map to 'cuda'
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )

        print("Model is on device:", next(self.model.parameters()).device)

        # Tokenize the input prompt
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Stream tokens as they are generated; skip_prompt avoids echoing
        # the input prompt back in the output
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the accumulated text as new tokens arrive
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
        thread.join()  # reap the worker once the streamer is exhausted
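

# A minimal usage sketch (assumed wiring, not part of the original file):
# hook the streaming generator into a Gradio interface, the usual front end
# for a @spaces.GPU app. Gradio treats a generator function as a streaming
# output, so each yielded string replaces the displayed text in place.
import gradio as gr

wrapper = ModelWrapper()

demo = gr.Interface(
    fn=wrapper.generate,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == "__main__":
    demo.launch()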