Spaces:
Running
Running
import gradio as gr | |
import torch, os | |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
from huggingface_hub import login | |
def init_model(): | |
global model, tokenizer | |
hf_token = os.getenv("HF_TOKEN") | |
if hf_token is None: | |
raise ValueError("Hugging Face token is not set. Please set the HF_TOKEN environment variable.") | |
login(hf_token, add_to_git_credential=True) | |
tokenizer = AutoTokenizer.from_pretrained("IlyaGusev/saiga_gemma2_10b", token=hf_token) | |
# Настройка квантования | |
quant_config = BitsAndBytesConfig( | |
load_in_8bit=True, # Использование 8-битного квантования | |
llm_int8_threshold=6.0, # Порог активации для 8-битных весов | |
) | |
model = AutoModelForCausalLM.from_pretrained( | |
"IlyaGusev/saiga_gemma2_10b", | |
token=hf_token, | |
torch_dtype=torch.float16, # Использование float16 для уменьшения потребления памяти | |
device_map="auto", # Автоматическое распределение модели на GPU | |
quantization_config=quant_config, # Применение конфигурации квантования | |
) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
def generate_response(prompt, max_length=100, temperature=0.7, top_p=0.85, repetition_penalty=1.1): | |
try: | |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) | |
attention_mask = torch.ones_like(input_ids).to(model.device) | |
output = model.generate( | |
input_ids, | |
attention_mask=attention_mask, | |
max_length=max_length, | |
temperature=temperature, | |
top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
do_sample=True, | |
num_return_sequences=1, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
response_text = tokenizer.decode(output[0], skip_special_tokens=True) | |
return response_text | |
except Exception as e: | |
return f"Извините, произошла ошибка при генерации ответа: {str(e)}" | |
init_model() | |
iface = gr.Interface( | |
fn=generate_response, | |
inputs=[ | |
gr.Textbox(lines=2, placeholder="Введите ваш текст здесь..."), | |
gr.Slider(20, 200, step=1, default=100, label="Максимальная длина"), | |
gr.Slider(0.1, 1.0, step=0.1, default=0.7, label="Температура"), | |
gr.Slider(0.1, 1.0, step=0.05, default=0.85, label="Top-p"), | |
gr.Slider(1.0, 2.0, step=0.1, default=1.1, label="Штраф за повторение") | |
], | |
outputs="text", | |
title="LLM Model Demo", | |
description="Введите текстовый запрос, чтобы сгенерировать ответ с помощью LLM модели." | |
) | |
if __name__ == "__main__": | |
iface.launch() | |