# interpolation_steps.py
# (GitHub page-scrape residue removed here: site navigation text and the
# copied line-number gutter, which were not part of the source file.)
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import math
from PIL import Image
import argparse
import random
# Command-line interface: the two prompts to interpolate between, plus
# output-file naming controls (base filename and a numeric index suffix).
parser = argparse.ArgumentParser()
parser.add_argument("--prompt_1", type=str, nargs="?",
                    help="the 1st prompt to render",
                    default="a monkey on acid in interstellar space")
parser.add_argument("--prompt_2", type=str, nargs="?",
                    help="the 2nd prompt to render",
                    default="an octupus having a rave underwater")
parser.add_argument("--index", type=int, default=0,
                    help="Index of the filename")
parser.add_argument("--filename", type=str, nargs="?", default="output",
                    help="filename without extension")
args = parser.parse_args()
# Run header: echo the job parameters and basic process-control hints
# before the (slow) model instantiation begins.
print("use cmd-c to quit \n\n")
print(f"Begin image sequence {args.filename}{args.index}")
print(f"prompt 1 - {args.prompt_1}")
print(f"prompt 2 - {args.prompt_2}")
# Plain string literal: the original had a spurious f-prefix with no placeholders.
print("if you need to pause use ctrl-z, if you want to run in background after pause, type bg")
print("Good luck \n\n")
# Instantiate the Stable Diffusion model (fixed 512x512 output resolution).
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
# Function for creating GIFs
def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    """Save a sequence of frames as an animated GIF.

    Args:
        filename: Output path for the GIF.
        images: Non-empty sequence of PIL-style images (each must support
            ``save(filename, **kwargs)``). The sequence is NOT modified.
        frames_per_second: Playback rate; converted to a per-frame duration
            in milliseconds (integer division of 1000).
        rubber_band: When True, append the reversed middle frames so the
            animation plays forward then backward ("boomerang" loop).
    """
    # Work on a copy: the original mutated the caller's list in place
    # when rubber_band was set, a surprising side effect.
    frames = list(images)
    if rubber_band:
        # Skip the first two and last frame so the loop endpoints
        # are not duplicated in the reversed pass.
        frames += frames[2:-1][::-1]
    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // frames_per_second,
        loop=0,  # loop forever
    )
# Both prompts share one style suffix; cap each at 50 space-separated words
# (the text encoder accepts only a limited token window).
_style_suffix = ",vivid colors, high detail, 4k, breathtaking, psychedelic art, trending on art station, unreal engine"


def _cap_words(text, limit=50):
    # Keep at most `limit` words, splitting on single spaces exactly
    # as the original expression did.
    return ' '.join(text.split(' ')[:limit])


prompt_1 = _cap_words(f"{args.prompt_1}{_style_suffix}")
prompt_2 = _cap_words(f"{args.prompt_2}{_style_suffix}")
encoding_1 = tf.squeeze(model.encode_text(prompt_1))
encoding_2 = tf.squeeze(model.encode_text(prompt_2))
# Seed derived from the prompt lengths plus a random offset, so repeated
# runs with the same prompts still vary.
seed = len(args.prompt_1) + len(args.prompt_2) + random.randint(0, 666)
# Latent-space noise: 512/8 = 64x64 spatial grid with 4 channels.
noise = tf.random.normal((512 // 8, 512 // 8, 4), seed=seed)
# Interpolate linearly between the two prompt encodings and render one
# image per step, then stitch the frames into a GIF.
interpolation_steps = 15
batch_size = 1
batches = interpolation_steps // batch_size

interpolated_encodings = tf.linspace(encoding_1, encoding_2, interpolation_steps)
batched_encodings = tf.split(interpolated_encodings, batches)

images = []
for encoding_batch in batched_encodings:
    generated = model.generate_image(
        encoding_batch,
        batch_size=batch_size,
        num_steps=25,
        diffusion_noise=noise,  # shared noise keeps frames visually coherent
    )
    images.extend(Image.fromarray(frame) for frame in generated)

export_as_gif(f"{args.filename}{args.index}.gif", images, rubber_band=False)