Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import json | |
from copy import deepcopy | |
from typing import TYPE_CHECKING, Optional, Union | |
import fastapi | |
from gradio_client.documentation import document, set_documentation_group | |
from gradio import utils | |
from gradio.data_classes import PredictBody | |
from gradio.exceptions import Error | |
from gradio.helpers import EventData | |
if TYPE_CHECKING: | |
from gradio.routes import App | |
set_documentation_group("routes") | |
class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/

    Nested dictionaries become nested `Obj` instances, so a value can be read either
    as an attribute (`obj.a.b`) or as an item (`obj["a"]["b"]`). Lists stay lists,
    with any dictionaries inside them converted to `Obj` as well.
    """

    def __init__(self, dict_):
        """
        Parameters:
            dict_: A mapping (a dict or another `Obj`) whose keys become attributes.
        """
        for key, value in dict_.items():
            setattr(self, key, self._wrap(value))

    @staticmethod
    def _wrap(value):
        """Recursively convert dicts (including dicts nested inside lists) to `Obj`."""
        if isinstance(value, dict):
            return Obj(value)
        if isinstance(value, list):
            # A list has no keys, so it cannot be turned into an Obj itself;
            # only its elements are converted.
            return [Obj._wrap(element) for element in value]
        return value

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        """Yield `(key, value)` pairs; directly nested `Obj` values are yielded as dicts."""
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        """Return True if `item` is a key of this object or of any nested `Obj` value."""
        if item in self.__dict__:
            return True
        return any(
            isinstance(value, Obj) and item in value
            for value in self.__dict__.values()
        )

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(name, request: gr.Request):
            print("Request headers dictionary:", request.headers)
            print("IP address:", request.client.host)
            return name
        io = gr.Interface(echo, "textbox", "textbox").launch()
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
            username: The logged in user, if any.
            kwargs: Attributes to expose when no fastapi.Request is given.
        """
        self.request = request
        self.username = username
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        """
        Convert a dictionary into an `Obj` (via a JSON round trip) so that its keys
        can be read as attributes. Any non-dict value is returned unchanged.
        """
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        """
        Resolve attributes that are not set on the instance itself: from the wrapped
        fastapi.Request when there is one, otherwise from the keyword arguments given
        to the constructor.
        Raises:
            AttributeError: if the attribute cannot be found.
        """
        # __getattr__ only runs after normal lookup has failed. Getting here for
        # `request` or `kwargs` means __init__ has not run on this instance (e.g.
        # while it is being copied or unpickled); reading them through `self.` below
        # would re-enter this method without end, so fail with AttributeError.
        if name in ("request", "kwargs"):
            raise AttributeError(f"'Request' object has no attribute '{name}'")
        if self.request:
            return self.dict_to_obj(getattr(self.request, name))
        else:
            try:
                obj = self.kwargs[name]
            except KeyError as ke:
                raise AttributeError(
                    f"'Request' object has no attribute '{name}'"
                ) from ke
            return self.dict_to_obj(obj)
class FnIndexInferError(Exception):
    """Raised by `infer_fn_index` when no dependency matches the given api_name."""
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    """
    Work out which dependency a request refers to.

    Parameters:
        app: The App whose Blocks dependencies are searched.
        api_name: The api_name to look for when the body carries no fn_index.
        body: The request payload; its `fn_index` wins when it is not None.
    Returns:
        `body.fn_index` if set, otherwise the index of the first dependency whose
        "api_name" equals `api_name`.
    Raises:
        FnIndexInferError: if `body.fn_index` is None and no dependency matches.
    """
    explicit_index = body.fn_index
    if explicit_index is not None:
        return explicit_index

    for position, dependency in enumerate(app.get_blocks().dependencies):
        if dependency["api_name"] == api_name:
            return position
    raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    """
    Build the `Request` object(s) handed to the prediction function.

    Parameters:
        app: The App whose Blocks dependencies are consulted.
        body: The request payload. Its `data` is replaced by `[body.session_hash]`
            when the dependency cancels other jobs.
        fn_index_inferred: Index of the dependency being run.
        username: The logged in user, forwarded to every `Request`.
        request: The fastapi.Request; only used when `body.request` is empty.
    Returns:
        A list of `Request` objects when `body.request` is set and the body is
        batched, otherwise a single `Request`.
    Raises:
        ValueError: if neither `body.request` nor `request` is available.
    """
    dependency = app.get_blocks().dependencies[fn_index_inferred]
    if dependency["cancels"]:
        # If this fn_index cancels jobs, then the only input we need is the
        # current session hash
        body.data = [body.session_hash]

    if not body.request:
        # Nothing serialized in the body: wrap the live fastapi request instead.
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        return Request(username=username, request=request)

    if body.batched:
        # One Request per entry of the batched request payload.
        return [Request(username=username, **req) for req in body.request]

    assert isinstance(body.request, dict)
    return Request(username=username, **body.request)
def restore_session_state(app: App, body: PredictBody):
    """
    Look up (creating if needed) the per-session state and iterators for a request.

    Parameters:
        app: The App holding `state_holder`, `iterators` and `iterators_to_reset`.
        body: The request payload; its `fn_index` and optional `session_hash` are read.
    Returns:
        A `(session_state, iterators)` tuple. Both are empty dicts when the body
        has no session hash.
    """
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    if session_hash is None:
        return {}, {}

    if session_hash not in app.state_holder:
        # First request of this session: seed the state with a private copy of
        # the value of every stateful block.
        initial_state = {}
        for _id, block in app.get_blocks().blocks.items():
            if getattr(block, "stateful", False):
                initial_state[_id] = deepcopy(getattr(block, "value", None))
        app.state_holder[session_hash] = initial_state
    session_state = app.state_holder[session_hash]

    # iterators_to_reset keeps track of the fn_indices that have been
    # cancelled. When a job is cancelled, the /reset route will mark the
    # job as having been reset. That way if the cancel job finishes BEFORE
    # the job being cancelled, the job being cancelled will not overwrite
    # the state of the iterator.
    pending_resets = app.iterators_to_reset[session_hash]
    if fn_index in pending_resets:
        pending_resets.remove(fn_index)
        return session_state, {}
    return session_state, app.iterators[session_hash]
async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred,
):
    """
    Run the dependency at `fn_index_inferred` through the app's Blocks and return
    the resulting output dictionary.

    Parameters:
        app: The App; its Blocks, `iterators` and session state are used.
        body: The request payload (data, event data, optional session hash / event id).
        gr_request: The Request (or list of Requests) forwarded to `process_api`.
        fn_index_inferred: Index of the dependency to run.
    Returns:
        The dictionary returned by `process_api` with its "iterator" entry removed.
        When a batch dependency is called with a non-batched body, "data" is
        unwrapped to its first element.
    """
    session_state, iterators = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().dependencies[fn_index_inferred]
    targets = dependency["targets"]
    target = targets[0] if targets else None
    # NOTE(review): a falsy target id (such as 0) is treated like "no target"
    # here — confirm that is intended.
    target_block = app.get_blocks().blocks.get(target) if target else None
    event_data = EventData(target_block, body.event_data)

    event_id = getattr(body, "event_id", None)
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)

    # A batch dependency called with a single, non-batched payload: wrap the
    # inputs in a one-element batch now and unwrap the output at the end.
    batch_in_single_out = not body.batched and dependency["batch"]
    inputs = [body.data] if batch_in_single_out else body.data

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterators=iterators,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
            )
        # Remember the iterator returned for this call (None if there is none)
        # under the session, keyed by the fn_index sent in the body.
        iterator = output.pop("iterator", None)
        if hasattr(body, "session_hash"):
            app.iterators[body.session_hash][fn_index] = iterator
        if isinstance(output, Error):
            raise output
    except BaseException:
        open_iterator = iterators.get(fn_index, None)
        if open_iterator is not None:  # close off any streams that are still open
            run_id = id(open_iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            # Appending None presumably marks each stream as finished — confirm
            # against the stream consumer.
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]
    return output