llinahosna commited on
Commit
34dc018
1 Parent(s): 36cabf5

Create dall_e.py

Browse files
Files changed (1) hide show
  1. dall_e.py +59 -0
dall_e.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, getpass
2
+
3
+ import numpy as np
4
+ import replicate
5
+
6
+
7
+ class DalleImageGenerator:
8
+ """Dall-e model using Replicate API"""
9
+ def __init__(self, token=None):
10
+ if "REPLICATE_API_TOKEN" not in os.environ:
11
+ if token is not None:
12
+ os.environ["REPLICATE_API_TOKEN"] = token
13
+ else:
14
+ print(f"Please go to https://replicate.com/docs/api for your Replicate API token.")
15
+ os.environ["REPLICATE_API_TOKEN"] = getpass.getpass(f"Input Replicate API Token:")
16
+
17
+ self.dalle = replicate.models.get("kuprel/min-dalle")
18
+
19
+ def generate_images(self, text, grid_size, text_adherence=2):
20
+ urls = self.dalle.predict(text=text, grid_size=grid_size, log2_supercondition_factor=text_adherence)
21
+ images = get_image(list(urls)[-1])
22
+ h, w = images.shape[:2]
23
+ h, w = h // grid_size, w // grid_size
24
+ return blockshaped(images, h, w)
25
+
26
+
27
+ def get_image(url):
28
+ """download image from a url"""
29
+ from urllib.request import Request, urlopen
30
+ import io
31
+ import PIL.Image as Image
32
+ hdr = {
33
+ 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11',
34
+ 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
35
+ 'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
36
+ 'Accept-Encoding': 'none',
37
+ 'Accept-Language': 'en-US,en;q=0.8',
38
+ 'Connection': 'keep-alive'}
39
+
40
+ # urllib.request.urlretrieve(url, f"local-filename.jpg")
41
+ req = Request(url, headers=hdr)
42
+ page = urlopen(req)
43
+ return np.array(Image.open(io.BytesIO(page.read())))
44
+
45
+
46
+ def blockshaped(arr, nrows, ncols):
47
+ """
48
+ Return an array of shape (n, nrows, ncols) where
49
+ n * nrows * ncols = arr.size
50
+
51
+ If arr is a 2D array, the returned array should look like n subblocks with
52
+ each subblock preserving the "physical" layout of arr.
53
+ """
54
+ h, w, c = arr.shape
55
+ assert h % nrows == 0, f"{h} rows is not evenly divisible by {nrows}"
56
+ assert w % ncols == 0, f"{w} cols is not evenly divisible by {ncols}"
57
+ return (arr.reshape(h//nrows, nrows, w//ncols, ncols, - 1)
58
+ .swapaxes(1,2)
59
+ .reshape(-1, nrows, ncols, 3))