import threading

import torch
import spaces
from auto_gptq import AutoGPTQForCausalLM  # added: required for from_quantized() below
from transformers import AutoTokenizer, TextIteratorStreamer

print("Is CUDA available?", torch.cuda.is_available())

# Assumed to be defined elsewhere in the original app; a placeholder is shown here
# so the snippet is self-contained. Replace with your quantized model repo id.
model_id = "<your-gptq-model-repo>"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)


class ModelWrapper:
    def __init__(self):
        self.model = None  # Model will be loaded when the GPU is allocated

    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Explicitly set device_map to 'cuda'
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )
        print("Model is on device:", next(self.model.parameters()).device)

        # Tokenize the input prompt and move it to the GPU
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Set up the streamer
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield generated text in real time as the streamer produces it
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text

        thread.join()  # added: make sure the generation thread has finished
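
# --- Hedged usage sketch ---
# The @spaces.GPU decorator suggests this runs in a Hugging Face Space; the Gradio
# wiring below is an assumption, not part of the original snippet. Because
# generate() is a generator, passing a generator function to gr.Interface lets
# Gradio stream partial output to the UI as it arrives.
import gradio as gr

wrapper = ModelWrapper()

def respond(prompt):
    # Re-yield the partial strings so Gradio updates the textbox incrementally
    yield from wrapper.generate(prompt)

demo = gr.Interface(fn=respond, inputs="text", outputs="text")
demo.launch()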