File size: 1,879 Bytes
7362797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import abc
from dataclasses import dataclass
from typing import Generator

import PIL.Image

# One element of a mixed text/image sequence: either a `str` or a PIL image.
# Original note: "images, joined retrieval queries, retrieval images".
# NOTE(review): this module does not show how retrieval relates to these
# aliases — confirm the note's meaning or drop it.
MixedTokenType = str | PIL.Image.Image
# An ordered list of text/image elements; used below as the prompt type and as
# the yield type of multimodal generation.
MixedSequenceType = list[MixedTokenType]


@dataclass
class StreamingImage:
    """An image paired with a flag saying whether it is the final one.

    NOTE(review): nothing in this module consumes this class. Presumably it
    wraps intermediate and finished frames during streamed image generation —
    confirm against callers.
    """

    # The image payload.
    image: PIL.Image.Image
    # Marks the final image; presumably False for partial results — confirm.
    final: bool


# Default values for the `cfg_image_weight` / `cfg_text_weight` parameters of
# the AbstractMultimodalGenerator methods declared later in this module.
# NOTE(review): "CFG" presumably stands for classifier-free guidance — confirm.

# Defaults used by generate_multimodal_streaming.
DEFAULT_MULTIMODAL_CFG_IMAGE: float = 1.2
DEFAULT_MULTIMODAL_CFG_TEXT: float = 3.0

# Defaults used by generate_image_streaming.
DEFAULT_IMAGE_CFG_IMAGE: float = 3.0
DEFAULT_IMAGE_CFG_TEXT: float = 3.0


class AbstractMultimodalGenerator(abc.ABC):
    """Abstract interface for streaming text, image, and mixed generation.

    All three methods are abstract, so a concrete subclass must override each
    of them before it can be instantiated. This interface fixes only the
    parameter names, types, and defaults; the meaning of each parameter is
    defined by the implementations.
    """

    @abc.abstractmethod
    def generate_text_streaming(
        self,
        prompts: list[MixedSequenceType],
        temp: float = 1.0,
        top_p: float = 0.8,
        seed: int | None = None,
    ) -> Generator[list[str], None, None]:
        """Stream generated text for a list of prompts.

        The base implementation does nothing and returns ``None`` if invoked
        directly (e.g. through ``super()``).

        Args:
            prompts: A list of mixed text/image sequences.
            temp: Defaults to 1.0. Presumably the sampling temperature —
                confirm against implementations.
            top_p: Defaults to 0.8. Presumably the nucleus-sampling
                threshold — confirm against implementations.
            seed: Optional integer seed, ``None`` by default. Presumably
                makes sampling reproducible — confirm.

        Returns:
            Declared as a generator yielding ``list[str]``; presumably one
            string per prompt — confirm against implementations.
        """

    @abc.abstractmethod
    def generate_image_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
        yield_every_n: int = 32,
        seed: int | None = None,
    ) -> Generator[PIL.Image.Image, None, None]:
        """Stream generated images for a single prompt.

        The base implementation does nothing and returns ``None`` if invoked
        directly (e.g. through ``super()``).

        Args:
            prompt: One mixed text/image sequence.
            temp: Defaults to 1.0. Presumably the sampling temperature —
                confirm against implementations.
            top_p: Defaults to 0.8. Presumably the nucleus-sampling
                threshold — confirm against implementations.
            cfg_image_weight: Defaults to ``DEFAULT_IMAGE_CFG_IMAGE``.
            cfg_text_weight: Defaults to ``DEFAULT_IMAGE_CFG_TEXT``.
            yield_every_n: Defaults to 32. Presumably the number of
                generation steps between yielded images — confirm.
            seed: Optional integer seed, ``None`` by default.

        Returns:
            Declared as a generator yielding ``PIL.Image.Image`` objects.
            NOTE(review): this module also defines ``StreamingImage``, which
            this signature does not use — confirm whether implementations
            yield that type instead.
        """

    @abc.abstractmethod
    def generate_multimodal_streaming(
        self,
        prompt: MixedSequenceType,
        temp: float = 1.0,
        top_p: float = 0.8,
        cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
        cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
        yield_every_n: int = 32,
        max_gen_tokens: int = 4096,
        repetition_penalty: float = 1.2,
        suffix_tokens: list[str] | None = None,
        seed: int | None = None,
    ) -> Generator[MixedSequenceType, None, None]:
        """Stream mixed text/image output for a single prompt.

        The base implementation does nothing and returns ``None`` if invoked
        directly (e.g. through ``super()``).

        Args:
            prompt: One mixed text/image sequence.
            temp: Defaults to 1.0. Presumably the sampling temperature —
                confirm against implementations.
            top_p: Defaults to 0.8. Presumably the nucleus-sampling
                threshold — confirm against implementations.
            cfg_image_weight: Defaults to ``DEFAULT_MULTIMODAL_CFG_IMAGE``.
            cfg_text_weight: Defaults to ``DEFAULT_MULTIMODAL_CFG_TEXT``.
            yield_every_n: Defaults to 32. Presumably the number of
                generation steps between yields — confirm.
            max_gen_tokens: Defaults to 4096. Presumably an upper bound on
                generated tokens — confirm.
            repetition_penalty: Defaults to 1.2.
            suffix_tokens: Optional list of strings, ``None`` by default;
                usage is implementation-defined.
            seed: Optional integer seed, ``None`` by default.

        Returns:
            Declared as a generator yielding mixed text/image sequences.
        """