Spaces:
Running
Running
File size: 4,459 Bytes
910dbfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from abc import ABC, abstractmethod
from typing import Type, TypeVar
import base64
import json
# constants
# A history entry starting with this prefix is treated by Claude.generate_body
# as an image attachment: the rest of the string is read as a file path.
# NOTE(review): the literal looks like a UTF-8 emoji decoded with a Thai code
# page (mojibake); it must match whatever code builds the chat history, so
# verify against the writer before "fixing" it.
image_embed_prefix = "๐ผ๏ธ๐ "
# When True, generate_body prints the assembled prompt/messages to stdout.
log_to_console = False
def encode_image(image_data):
    """Wrap raw image bytes in a base64 image-source dict.

    The image format is detected from the leading "magic" bytes. Supported
    formats are PNG, JPEG, GIF (87a and 89a) and WebP.

    Args:
        image_data: The raw (not yet base64-encoded) image bytes.

    Returns:
        A dict with ``type`` set to ``"base64"``, ``media_type`` set to
        ``"image/<format>"`` and ``data`` holding the base64-encoded image as
        a str.

    Raises:
        ValueError: If the bytes do not match any supported format.
    """
    # Check the full data rather than a short slice: the GIF signature is
    # six bytes long, so a 4-byte slice could never match it.
    if image_data.startswith(b'\x89PNG'):
        image_type = 'png'
    elif image_data.startswith(b'\xFF\xD8'):
        image_type = 'jpeg'
    elif image_data.startswith((b'GIF87a', b'GIF89a')):
        image_type = 'gif'
    elif image_data.startswith(b'RIFF') and image_data[8:12] == b'WEBP':
        # WebP is a RIFF container; the form type sits at bytes 8..11.
        image_type = 'webp'
    else:
        raise ValueError("Unknown image type")
    return {"type": "base64",
            "media_type": "image/" + image_type,
            "data": base64.b64encode(image_data).decode('utf-8')}
# TypeVar bounded by LLM; retained in case other modules reference it.
LLMClass = TypeVar('LLMClass', bound='LLM')


class LLM(ABC):
    """Abstract adapter between a chat history and one model family's API.

    Subclasses turn a message plus history into a JSON request body
    (``generate_body``) and extract the generated text from a decoded
    response (``read_response``). Use ``create_llm`` to pick the adapter for
    a model id.
    """

    # Declared static because the concrete adapters implement these hooks as
    # @staticmethod; abstractmethod must be the innermost decorator.
    @staticmethod
    @abstractmethod
    def generate_body(message, history, system_prompt, temperature, max_tokens):
        """Return the JSON request body (a str) for ``message`` and ``history``."""

    @staticmethod
    @abstractmethod
    def read_response(response_body):
        """Return the generated text from a decoded response body."""

    @staticmethod
    def create_llm(model: str) -> 'LLM':
        """Return an adapter instance for ``model``.

        Args:
            model: Model id. Ids starting with ``"anthropic.claude"`` map to
                ``Claude``, and ids starting with ``"mistral."`` map to
                ``Mistral``.

        Raises:
            ValueError: If no adapter matches ``model``.
        """
        if model.startswith("anthropic.claude"):
            return Claude()
        if model.startswith("mistral."):
            return Mistral()
        raise ValueError(f"Unsupported model: {model}")
class Claude(LLM):
    """Adapter for Anthropic Claude models using the messages request format."""

    @staticmethod
    def generate_body(message, history, system_prompt, temperature, max_tokens):
        """Build the JSON request body for a Claude messages call.

        Args:
            message: The new user message; skipped when empty or None.
            history: Iterable of ``(human, assistant)`` pairs; either element
                may be None. A ``human`` string starting with
                ``image_embed_prefix`` is treated as a path to an image file,
                which is read and attached as an image part.
            system_prompt: Sent as the ``system`` field.
            temperature: Sent as the ``temperature`` field.
            max_tokens: Sent as the ``max_tokens`` field.

        Returns:
            The request body as a JSON string.
        """
        messages = []
        user_parts = []
        for human, assistant in history:
            if human is not None:
                if human.startswith(image_embed_prefix):
                    # Remove exactly the prefix. str.lstrip would treat the
                    # prefix as a character set and could also strip leading
                    # characters of the path itself.
                    image_path = human[len(image_embed_prefix):]
                    with open(image_path, mode="rb") as f:
                        content = f.read()
                    user_parts.append({"type": "image",
                                       "source": encode_image(content)})
                else:
                    user_parts.append({"type": "text", "text": human})
            if assistant is not None:
                # Consecutive human entries are merged into one user turn,
                # which is flushed before the assistant reply.
                if user_parts:
                    messages.append({"role": "user", "content": user_parts})
                    user_parts = []
                messages.append({"role": "assistant", "content": assistant})
        if message:
            # The new message, not the last history entry, ends the user turn.
            user_parts.append({"type": "text", "text": message})
        if user_parts:
            messages.append({"role": "user", "content": user_parts})
        if log_to_console:
            print(f"br_prompt: {str(messages)}")
        body = json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "system": system_prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": messages
        })
        return body

    @staticmethod
    def read_response(response_body) -> str:
        """Return the ``text`` of the first item in the response's ``content`` list."""
        return response_body.get('content')[0].get('text')
class Mistral(LLM):
    """Adapter for Mistral instruct models using a single prompt string."""

    @staticmethod
    def generate_body(message, history, system_prompt, temperature, max_tokens):
        """Build the JSON request body for a Mistral text-completion call.

        Args:
            message: The new user message; skipped when empty or None.
            history: Iterable of ``(human, assistant)`` pairs; either element
                may be None, in which case it is left out of the prompt.
            system_prompt: Accepted for interface compatibility, but it is
                not included in the prompt.
            temperature: Sent as the ``temperature`` field.
            max_tokens: Sent as the ``max_tokens`` field.

        Returns:
            The request body as a JSON string.
        """
        # NOTE(review): image-prefixed history entries are inserted verbatim
        # as text here; confirm that is intended for text-only models.
        parts = ["<s>"]
        for human, assistant in history:
            if human is not None:
                parts.append(f"[INST] {human} [/INST]\n")
            if assistant is not None:
                parts.append(f"{assistant}</s>\n")
        if message:
            parts.append(f"[INST] {message} [/INST]")
        prompt = "".join(parts)
        if log_to_console:
            print(f"br_prompt: {str(prompt)}")
        body = json.dumps({
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
        })
        return body

    @staticmethod
    def read_response(response_body) -> str:
        """Return the ``text`` of the first item in the response's ``outputs`` list."""
        return response_body.get('outputs')[0].get('text')
|