import io
import os

import hf_transfer
import httpx
import numpy as np
import requests
import spaces
import torch
from tqdm import tqdm
from transformers import DynamicCache

os.makedirs("tmp", exist_ok=True)

def generate_answer(
    model, tokenizer, question_ids, cache, context_length, max_new_tokens
):
    """
    Generate an answer to a question using greedy decoding.

    Parameters:
        model: Model instance.
        tokenizer: Tokenizer instance.
        question_ids (torch.Tensor): Tokenized question.
        cache (DynamicCache): Key-value cache holding the compressed context.
        context_length (int): Length of the context already stored in the cache.
        max_new_tokens (int): Max number of tokens to generate.

    Returns:
        str: Generated answer.
    """
    question_ids = question_ids.to(model.device)
    # Remember the per-layer cache lengths so the cache can be restored
    # (and reused for another question) after generation.
    cache_seq_lengths = [
        cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))
    ]
    # The question continues right after the cached context, so its position
    # ids start at context_length.
    position_ids = torch.arange(
        context_length, context_length + question_ids.shape[1], device=model.device
    ).unsqueeze(0)
    # Prefill the question tokens; only the last logit row is needed.
    outputs = model(
        input_ids=question_ids,
        past_key_values=cache,
        position_ids=position_ids,
        num_logits_to_keep=1,
    )
    position_ids = position_ids[:, -1:] + 1
    generated_ids = [outputs.logits[0, -1].argmax()]
    # Greedy decoding, one token at a time.
    for step in range(max_new_tokens - 1):
        outputs = model(
            input_ids=generated_ids[-1].unsqueeze(0).unsqueeze(0),
            past_key_values=cache,
            position_ids=position_ids + step,
        )
        new_id = outputs.logits[0, -1].argmax()
        generated_ids.append(new_id)
        if new_id.item() == model.generation_config.eos_token_id:
            break
    answer = tokenizer.decode(torch.stack(generated_ids), skip_special_tokens=True)
    # Truncate the question/answer tokens from the cache so it once again
    # holds only the original context.
    cache.key_cache = [
        key[:, :, :c] for key, c in zip(cache.key_cache, cache_seq_lengths)
    ]
    cache.value_cache = [
        value[:, :, :c] for value, c in zip(cache.value_cache, cache_seq_lengths)
    ]
    return answer


def get_condense_kv_cache(context: str):
    """Request a compressed KV cache for `context` from the Condenses API."""
    url = "https://ncs-client.condenses.ai/api/organic"
    payload = {
        "tier": "research",
        "target_model": "mistralai/Mistral-7B-Instruct-v0.2",
        "context": context,
        "top_incentive": 0.1,
    }
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "user-api-key": os.getenv("CONDENSE_API_KEY"),
    }
    response = requests.post(url, json=payload, headers=headers).json()
    print(response)
    numpy_kv_cache, error = load_npy_from_url(response["compressed_kv_url"])
    if error:
        raise RuntimeError(f"Failed to download compressed KV cache: {error}")
    # Move the compressed cache to the GPU and wrap it in a DynamicCache.
    kv_cache = DynamicCache.from_legacy_cache(
        torch.from_numpy(numpy_kv_cache).to("cuda").to(torch.bfloat16)
    )
    return kv_cache
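

# A minimal sketch (an assumption, not part of the original Space): from_legacy_cache
# above treats the downloaded array as one (key, value) pair per layer, i.e. a layout of
# (num_layers, 2, batch, num_kv_heads, seq_len, head_dim). The concrete sizes below are
# illustrative Mistral-7B-style values, not something the API guarantees.
def _example_legacy_cache_layout():
    num_layers, batch, num_kv_heads, seq_len, head_dim = 32, 1, 8, 128, 128
    dummy = np.zeros(
        (num_layers, 2, batch, num_kv_heads, seq_len, head_dim), dtype=np.float16
    )
    cache = DynamicCache.from_legacy_cache(torch.from_numpy(dummy))
    # Each layer now holds key/value tensors of shape
    # (batch, num_kv_heads, seq_len, head_dim).
    assert cache.get_seq_length(0) == seq_len
    return cache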


def load_npy_from_url(url, max_size_mb=1024):
    """
    Load a `.npy` file from a URL using hf_transfer.

    Parameters:
        url (str): URL of the `.npy` file.
        max_size_mb (int): Max file size in megabytes.

    Returns:
        tuple: (Loaded NumPy array or None, error message).
    """
    try:
        # Check the file size with a HEAD request before downloading.
        with httpx.Client() as client:
            response = client.head(url)
            if response.status_code != 200:
                return None, f"Failed to fetch file info: HTTP {response.status_code}"
            content_length = int(response.headers.get("content-length", 0))
            if content_length > max_size_mb * 1024 * 1024:
                return None, (
                    f"File too large: {content_length / (1024 * 1024):.1f}MB "
                    f"exceeds {max_size_mb}MB limit"
                )

        filename = os.path.join("tmp", url.split("/")[-1])
        # Download with hf_transfer, reporting progress through tqdm.
        with tqdm(total=content_length, unit="B", unit_scale=True, desc="Downloading") as pbar:
            hf_transfer.download(
                url=url, filename=filename, chunk_size=1024 * 1024, callback=pbar.update
            )
        with open(filename, "rb") as f:
            buffer = io.BytesIO(f.read())
        data = np.load(buffer)
        os.remove(filename)
        return data, ""
    except Exception as e:
        return None, str(e)
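

# Example usage (a sketch, not part of the original file): compress a context through the
# Condenses API, then answer questions against the returned KV cache. The model name
# mirrors `target_model` in the API payload; the "[/INST]" framing of the questions is an
# assumption about how the cached context was prompted and may need adjusting.
# Requires a GPU and a CONDENSE_API_KEY environment variable.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="cuda"
    )

    context = "Paris hosted the 2024 Summer Olympics, with events held across the city."
    cache = get_condense_kv_cache(context)
    context_length = cache.get_seq_length(0)

    for question in [
        "\n\nWhich city hosted the 2024 Summer Olympics? [/INST]",
        "\n\nIn which season were those Games held? [/INST]",
    ]:
        question_ids = tokenizer(
            question, return_tensors="pt", add_special_tokens=False
        ).input_ids
        # generate_answer truncates the cache back to context_length afterwards,
        # so the same compressed cache can serve several questions.
        answer = generate_answer(
            model, tokenizer, question_ids, cache, context_length, max_new_tokens=64
        )
        print(question.strip(), "->", answer)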