Trellis / trellis /pipelines /samplers /classifier_free_guidance_mixin.py
crevelop's picture
feat: upload model files and gradio app
5e116f2 unverified
raw
history blame contribute delete
454 Bytes
from typing import *
class ClassifierFreeGuidanceSamplerMixin:
    """
    Sampler mixin that blends a conditional and a negative-conditional
    model prediction (classifier-free guidance).

    The mixin does not run the model itself: it delegates to the
    ``_inference_model`` of the next class in the method resolution order,
    so it must be combined with a class that provides that method.
    """

    def _inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, **kwargs):
        """
        Return the guided prediction for one sampling step.

        Args:
            model: Forwarded unchanged to the parent ``_inference_model``.
            x_t: Forwarded unchanged to the parent ``_inference_model``.
            t: Forwarded unchanged to the parent ``_inference_model``.
            cond: Conditioning used for the first (positive) parent call.
            neg_cond: Conditioning used for the second (negative) parent call.
            cfg_strength: Guidance weight ``w`` in the blend below.
            **kwargs: Extra keyword arguments passed to both parent calls.

        Returns:
            ``(1 + w) * positive - w * negative``, where ``positive`` and
            ``negative`` are the two parent predictions.
        """
        # Bind the cooperative parent once; both passes go through it.
        parent = super()

        # The positive pass runs first, then the negative pass.
        positive = parent._inference_model(model, x_t, t, cond, **kwargs)
        negative = parent._inference_model(model, x_t, t, neg_cond, **kwargs)

        # A weight of 0 leaves the positive prediction as the only
        # contributing term.
        positive_weight = 1 + cfg_strength
        return positive_weight * positive - cfg_strength * negative