Etash Guha
commited on
Commit
·
32bc229
1
Parent(s):
8d320a4
pease
Browse files- generators/model.py +2 -3
generators/model.py
CHANGED
|
@@ -123,6 +123,7 @@ class Samba():
|
|
| 123 |
|
| 124 |
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
| 125 |
resps = []
|
|
|
|
| 126 |
for i in range(num_comps):
|
| 127 |
payload = {
|
| 128 |
"inputs": [dataclasses.asdict(message) for message in messages],
|
|
@@ -145,19 +146,17 @@ class Samba():
|
|
| 145 |
"Content-Type": "application/json"
|
| 146 |
}
|
| 147 |
post_response = requests.post(url, json=payload, headers=headers, stream=True)
|
| 148 |
-
|
| 149 |
response_text = ""
|
| 150 |
for line in post_response.iter_lines():
|
| 151 |
if line.startswith(b"data: "):
|
| 152 |
data_str = line.decode('utf-8')[6:]
|
| 153 |
try:
|
| 154 |
line_json = json.loads(data_str)
|
| 155 |
-
content = line_json.get("stream_token", "")
|
| 156 |
if content:
|
| 157 |
response_text += content
|
| 158 |
except json.JSONDecodeError as e:
|
| 159 |
pass
|
| 160 |
-
resps.append(response_text)
|
| 161 |
|
| 162 |
if num_comps == 1:
|
| 163 |
return resps[0]
|
|
|
|
| 123 |
|
| 124 |
def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
|
| 125 |
resps = []
|
| 126 |
+
|
| 127 |
for i in range(num_comps):
|
| 128 |
payload = {
|
| 129 |
"inputs": [dataclasses.asdict(message) for message in messages],
|
|
|
|
| 146 |
"Content-Type": "application/json"
|
| 147 |
}
|
| 148 |
post_response = requests.post(url, json=payload, headers=headers, stream=True)
|
|
|
|
| 149 |
response_text = ""
|
| 150 |
for line in post_response.iter_lines():
|
| 151 |
if line.startswith(b"data: "):
|
| 152 |
data_str = line.decode('utf-8')[6:]
|
| 153 |
try:
|
| 154 |
line_json = json.loads(data_str)
|
| 155 |
+
content = line_json['0'].get("stream_token", "")
|
| 156 |
if content:
|
| 157 |
response_text += content
|
| 158 |
except json.JSONDecodeError as e:
|
| 159 |
pass
|
|
|
|
| 160 |
|
| 161 |
if num_comps == 1:
|
| 162 |
return resps[0]
|