# app.py for a Hugging Face Space running on ZeroGPU
import gradio as gr
import torch
import threading
import spaces
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM
# Model identifier
model_id = "jncraton/SmolLM2-1.7B-Instruct-ct2-int8"
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False, trust_remote_code=True)
print("Is CUDA available?", torch.cuda.is_available())
class ModelWrapper:
    def __init__(self):
        self.model = None  # Model will be loaded when the GPU is allocated

    @spaces.GPU
    def generate(self, prompt):
        if self.model is None:
            # Explicitly set device_map to 'cuda'
            self.model = AutoGPTQForCausalLM.from_quantized(
                model_id,
                device_map={'': 'cuda:0'},
                trust_remote_code=True,
            )
            self.model.eval()
            print("Model is on device:", next(self.model.parameters()).device)

        # Tokenize the input prompt
        inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
        print("Inputs are on device:", inputs['input_ids'].device)

        # Set up the streamer
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

        # Prepare generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            max_new_tokens=512,
        )

        # Start generation in a separate thread to enable streaming
        thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield the generated text incrementally as it streams in
        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text
# Instantiate the model wrapper
model_wrapper = ModelWrapper()
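
# Hypothetical local sanity check (not part of the Space UI): stream tokens to stdout.
# for partial in model_wrapper.generate("Write a haiku about GPUs."):
#     print(partial, end="\r")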
# Create the Gradio interface
interface = gr.Interface(
    fn=model_wrapper.generate,
    inputs=gr.Textbox(lines=5, label="Input Prompt"),
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="Mistral-Large-Instruct-2407 Text Completion",
    description="Enter a prompt and receive a text completion using the Mistral-Large-Instruct-2407 INT4 model.",
    allow_flagging='never',
    live=False,
    cache_examples=False,
)
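
# Older Gradio (3.x) releases may require enabling the request queue for streaming
# generator outputs, e.g. interface.queue() before launching; recent releases
# enable the queue by default.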
if __name__ == "__main__":
    interface.launch()