import os
import subprocess
from huggingface_hub import hf_hub_download, list_repo_files
import gradio as gr
from typing import Callable
import base64
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from threading import Thread
from transformers import TextIteratorStreamer


def get_fn(model_path: str, **model_kwargs):
    """Create a chat function with the specified model."""

    # Initialize tokenizer and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Simple flash-attention installation attempt
    try:
        subprocess.run(
            'pip install flash-attn --no-build-isolation',
            env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
            shell=True,
            check=True
        )
        # Try loading the model with flash attention
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            quantization_config=quantization_config,
            attn_implementation="flash_attention_2",
        )
    except Exception as e:
        print(f"Flash Attention failed, falling back to default attention: {str(e)}")
        # Fall back to the default attention implementation
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=quantization_config,
        )

    def predict(
        message: str,
        history,
        system_prompt: str,
        temperature: float,
        max_new_tokens: int,
        top_k: int,
        repetition_penalty: float,
        top_p: float
    ):
        try:
            # Format the conversation with the ChatML template
            instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
            for user_msg, assistant_msg in history:
                instruction += f'<|im_start|>user\n{user_msg}\n<|im_end|>\n<|im_start|>assistant\n{assistant_msg}\n<|im_end|>\n'
            instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'

            streamer = TextIteratorStreamer(
                tokenizer, skip_prompt=True, skip_special_tokens=True
            )
            enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
            input_ids, attention_mask = enc.input_ids, enc.attention_mask

            # Truncate to the context window if needed
            if input_ids.shape[1] > 8192:  # Using n_ctx from original
                input_ids = input_ids[:, -8192:]
                attention_mask = attention_mask[:, -8192:]

            generate_kwargs = dict(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                streamer=streamer,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                top_p=top_p
            )

            # Run generation in a background thread and stream tokens as they arrive
            t = Thread(target=model.generate, kwargs=generate_kwargs)
            t.start()

            response_text = ""
            for new_token in streamer:
                if new_token in ["<|endoftext|>", "<|im_end|>"]:
                    break
                response_text += new_token
                yield response_text.strip()

            if not response_text.strip():
                yield "I apologize, but I was unable to generate a response. Please try again."
        except Exception as e:
            print(f"Error during generation: {str(e)}")
            yield f"An error occurred: {str(e)}"

    return predict


def get_image_base64(url: str, ext: str):
    with open(url, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return "data:image/" + ext + ";base64," + encoded_string


def handle_user_msg(message: str):
    if type(message) is str:
        return message
    elif type(message) is dict:
        if message["files"] is not None and len(message["files"]) > 0:
            ext = os.path.splitext(message["files"][-1])[1].strip(".")
            if ext.lower() in ["png", "jpg", "jpeg", "gif", "pdf"]:
                encoded_str = get_image_base64(message["files"][-1], ext)
            else:
                raise NotImplementedError(f"Not supported file type {ext}")
            content = [
                {"type": "text", "text": message["text"]},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": encoded_str,
                    }
                },
            ]
        else:
            content = message["text"]
        return content
    else:
        raise NotImplementedError


def get_interface_args(pipeline):
    if pipeline == "chat":
        inputs = None
        outputs = None

        def preprocess(message, history):
            messages = []
            files = None
            for user_msg, assistant_msg in history:
                if assistant_msg is not None:
                    messages.append({"role": "user", "content": handle_user_msg(user_msg)})
                    messages.append({"role": "assistant", "content": assistant_msg})
                else:
                    files = user_msg

            if type(message) is str and files is not None:
                message = {"text": message, "files": files}
            elif type(message) is dict and files is not None:
                if message["files"] is None or len(message["files"]) == 0:
                    message["files"] = files
            messages.append({"role": "user", "content": handle_user_msg(message)})
            return {"messages": messages}

        postprocess = lambda x: x
    else:
        # Add other pipeline types when they are needed
        raise ValueError(f"Unsupported pipeline type: {pipeline}")
    return inputs, outputs, preprocess, postprocess


def get_pipeline(model_name):
    # Determine the pipeline type based on the model name.
    # For now, all models are assumed to be chat models.
    return "chat"


def get_model_path(name: str = None, model_path: str = None) -> str:
    """Get the local path or Hugging Face Hub ID of the model."""
    if model_path:
        return model_path
    if name:
        if "/" in name:
            return name  # Return the HF model ID directly
        else:
            # A mapping of friendly names to HF model IDs can be maintained here
            model_mapping = {
                # Add any default model mappings here
                "example-model": "organization/model-name"
            }
            if name not in model_mapping:
                raise ValueError(f"Unknown model name: {name}")
            return model_mapping[name]
    raise ValueError("Either name or model_path must be provided")


def registry(name: str = None, model_path: str = None, **kwargs):
    """Create a Gradio ChatInterface with similar styling and parameters."""
    model_path = get_model_path(name, model_path)
    fn = get_fn(model_path, **kwargs)

    interface = gr.ChatInterface(
        fn=fn,
        additional_inputs_accordion=gr.Accordion("⚙️ Parameters", open=False),
        additional_inputs=[
            gr.Textbox(
                "You are a helpful AI assistant.",
                label="System prompt"
            ),
            gr.Slider(0, 1, 0.7, label="Temperature"),
            gr.Slider(128, 4096, 1024, label="Max new tokens"),
            gr.Slider(1, 80, 40, label="Top K sampling"),
            gr.Slider(0, 2, 1.1, label="Repetition penalty"),
            gr.Slider(0, 1, 0.95, label="Top P sampling"),
        ],
    )

    return interface
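

# Example usage (an illustrative sketch, not part of the original module).
# The model ID below is only a placeholder: any causal LM that follows the ChatML
# prompt format assumed by `predict` should work. The 8-bit quantized load above
# assumes a CUDA-capable environment with bitsandbytes installed.
if __name__ == "__main__":
    demo = registry(model_path="Qwen/Qwen2.5-0.5B-Instruct")
    demo.launch()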