import json
from typing import Any, Dict, Union

import requests

# json_schema_to_gbnf is only used to convert the JSON schema to a GBNF grammar;
# the main interface is the HTTP server, not the llama_cpp library directly.
from llama_cpp import json_schema_to_gbnf


def llm_streaming(prompt: str, pydantic_model_class, return_pydantic_object=False) -> Union[str, Dict[str, Any]]:
    # Convert the Pydantic model's JSON schema into a GBNF grammar that constrains the server's output.
    schema = pydantic_model_class.model_json_schema()
    # The optional "example" field in the schema is not needed for grammar generation.
    if "example" in schema:
        del schema["example"]
    json_schema = json.dumps(schema)
    grammar = json_schema_to_gbnf(json_schema)
    payload = {
        "stream": True,
        "max_tokens": 1000,
        "grammar": grammar,
        "temperature": 1.0,
        "messages": [
            {
                "role": "user",
                "content": prompt,
            }
        ],
    }
    headers = {
        "Content-Type": "application/json",
    }

    response = requests.post(
        "http://localhost:5834/v1/chat/completions",
        headers=headers,
        json=payload,
        stream=True,
    )
output_text = ""
for chunk in response.iter_lines():
if chunk:
chunk = chunk.decode("utf-8")
if chunk.startswith("data: "):
chunk = chunk.split("data: ")[1]
if chunk.strip() == "[DONE]":
break
chunk = json.loads(chunk)
new_token = chunk.get('choices')[0].get('delta').get('content')
if new_token:
output_text = output_text + new_token
print(new_token,sep='',end='',flush=True)
#else:
# raise Exception(f"Parse error, expecting stream:{str(chunk)}")

    if return_pydantic_object:
        model_object = pydantic_model_class.model_validate_json(output_text)
        return model_object
    else:
        json_output = json.loads(output_text)
        return json_output


def replace_text(template: str, replacements: dict) -> str:
    for key, value in replacements.items():
        template = template.replace(f"{{{key}}}", value)
    return template
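
# Illustrative example of the placeholder convention used by replace_text:
#   replace_text("Summarize this: {text}", {"text": "some document"})
#   -> "Summarize this: some document"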


def query_ai_prompt(prompt, replacements, model_class):
    prompt = replace_text(prompt, replacements)
    # print('prompt')
    # print(prompt)
    return llm_streaming(prompt, model_class)
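

if __name__ == "__main__":
    # Usage sketch (illustrative): assumes a llama.cpp-compatible server with GBNF
    # grammar support is already running at http://localhost:5834, and that
    # pydantic v2 is installed. The Joke model below is a made-up example schema.
    from pydantic import BaseModel

    class Joke(BaseModel):
        setup: str
        punchline: str

    result = query_ai_prompt(
        "Tell me a joke about {topic}. Answer as JSON.",
        {"topic": "computers"},
        Joke,
    )
    print(result)  # e.g. {"setup": "...", "punchline": "..."}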