Spaces:
Runtime error
Runtime error
from flask import Flask, request, jsonify | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
app = Flask(__name__)

# Load the model and tokenizer once, at import time, so every request reuses them.
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# generate() is given tokenizer.pad_token_id, so the tokenizer needs a pad token.
# Reuse EOS instead of adding a new "[PAD]" token: adding one forces
# resize_token_embeddings(), which grows the embedding and output matrices with
# an untrained row that sampling (do_sample=True) could then emit.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def _build_prompt(messages):
    """Render a chat history into the Mistral-style ``[INST]`` prompt string.

    Layout: ``<s>[INST] u1 [/INST] a1</s> [INST] u2 [/INST]``. Each assistant
    turn is closed with ``</s>``. Nothing is appended after the final message,
    so when the history ends with a user turn the model continues with the
    assistant's reply. Messages whose role is neither ``"user"`` nor
    ``"assistant"`` are skipped.

    Raises:
        ValueError: if a message is not an object with "role" and "content".
    """
    parts = ["<s>"]
    for i, message in enumerate(messages):
        if not isinstance(message, dict) or "role" not in message or "content" not in message:
            raise ValueError("Each message must be an object with 'role' and 'content'")
        role = message["role"]
        content = message["content"]
        if role == "user":
            # The very first message sits directly after <s>; later ones are space-separated.
            prefix = "" if i == 0 else " "
            parts.append(f"{prefix}[INST] {content} [/INST]")
        elif role == "assistant":
            # Close every assistant turn with EOS; the original code instead put a single
            # </s> after the final [/INST], telling the model the sequence was over.
            parts.append(f" {content}</s>")
    return "".join(parts)


# NOTE(review): no route was registered for this view in the visible source, so it was
# unreachable. The "/chat" path and POST method are assumptions; confirm against clients.
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a model reply for a chat history posted as JSON.

    Expects a body like ``{"messages": [{"role": "user", "content": "..."}, ...]}``.

    Returns:
        ``{"response": <generated text>}`` on success, containing only the newly
        generated text (not the prompt), or ``({"error": ...}, 400)`` when the body
        is not a JSON object or the messages are missing or malformed.
    """
    # silent=True: a missing, non-JSON or unparsable body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    messages = data.get("messages", [])
    if not messages:
        return jsonify({"error": "No messages provided"}), 400
    if not isinstance(messages, list):
        return jsonify({"error": "'messages' must be a list"}), 400
    try:
        conversation = _build_prompt(messages)
    except ValueError as err:
        return jsonify({"error": str(err)}), 400

    # The prompt already contains a literal <s>; don't let the tokenizer add a second BOS.
    encoded = tokenizer(conversation, return_tensors="pt", add_special_tokens=False).to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    generated_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=50,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
    )
    # generate() returns prompt + continuation; decode only the newly generated tokens
    # so the response does not echo the whole conversation back to the client.
    decoded = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
    return jsonify({"response": decoded})
if __name__ == '__main__':
    import os

    # Listen on all interfaces. The port can be overridden through the PORT
    # environment variable (e.g. by a hosting platform) and defaults to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", 5000)))