Health/Assets/_VoiceAssistant/AIChatTookit/Scripts/LLM/SparkAI/ChatSpark.cs

330 lines
9.9 KiB
C#
Raw Normal View History

2023-11-21 08:57:37 +00:00
using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using UnityEngine;
/// <summary>
/// LLM backend for iFlytek (Xunfei) Spark. Serializes the conversation history into a
/// Spark request, opens a signed WebSocket connection, accumulates the streamed reply
/// fragments and invokes the callback once the final fragment (status 2) arrives.
/// </summary>
public class ChatSpark : LLM
{
    #region Inspector settings
    /// <summary>
    /// Xunfei application credentials (app id, API key and API secret are read from it).
    /// </summary>
    [SerializeField] private XunfeiSettings m_XunfeiSettings;
    /// <summary>
    /// Spark model version to talk to; selects both the endpoint URL and the request domain.
    /// </summary>
    [Header("选择星火大模型版本")]
    [SerializeField] private ModelType m_SparkModel = ModelType.V15;
    #endregion

    /// <summary>Chat endpoint used for <see cref="ModelType.V15"/>.</summary>
    private const string SparkV15Url = "https://spark-api.xf-yun.com/v1.1/chat";
    /// <summary>Chat endpoint used for every other model version.</summary>
    private const string SparkV20Url = "https://spark-api.xf-yun.com/v2.1/chat";
    /// <summary>Size of the scratch buffer used for each WebSocket receive call.</summary>
    private const int ReceiveBufferSize = 4096;

    private void Awake()
    {
        OnInit();
    }

    /// <summary>
    /// Resolves the credentials component and picks the endpoint URL for the selected model.
    /// </summary>
    private void OnInit()
    {
        // Only fall back to a component on this GameObject when nothing was assigned in the
        // Inspector; unconditionally calling GetComponent would discard a reference to a
        // settings component living on another GameObject (and leave the field null).
        if (m_XunfeiSettings == null)
        {
            m_XunfeiSettings = GetComponent<XunfeiSettings>();
        }
        url = m_SparkModel == ModelType.V15 ? SparkV15Url : SparkV20Url;
    }

    /// <summary>
    /// Sends a message; delegates entirely to the base implementation.
    /// </summary>
    /// <param name="_msg">Message text.</param>
    /// <param name="_callback">Invoked with the assistant's reply.</param>
    public override void PostMsg(string _msg, Action<string> _callback)
    {
        base.PostMsg(_msg, _callback);
    }

    /// <summary>
    /// Coroutine that waits one frame, builds the Spark request from the whole conversation
    /// history (<c>m_DataList</c>) and starts the WebSocket exchange. It does not wait for the
    /// reply; <paramref name="_callback"/> is invoked later from <see cref="ConnectHost"/>.
    /// </summary>
    /// <param name="_postWord">
    /// Not read here. NOTE(review): presumably the base class already appended it to
    /// m_DataList before starting this coroutine — verify against LLM.PostMsg.
    /// </param>
    /// <param name="_callback">Receives the complete assistant reply.</param>
    public override IEnumerator Request(string _postWord, System.Action<string> _callback)
    {
        yield return null;

        RequestData requestData = new RequestData();
        requestData.header.app_id = m_XunfeiSettings.m_AppID;
        // The domain must match the endpoint version (v1.5 vs v2).
        requestData.parameter.chat.domain = GetDomain();

        // Copy the conversation history into the wire format.
        List<PostMsgData> history = new List<PostMsgData>(m_DataList.Count);
        for (int i = 0; i < m_DataList.Count; i++)
        {
            history.Add(new PostMsgData
            {
                role = m_DataList[i].role,
                content = m_DataList[i].content
            });
        }
        requestData.payload.message.text = history;

        string json = JsonUtility.ToJson(requestData);
        ConnectHost(json, _callback);
    }

    /// <summary>
    /// Returns the request domain for the selected model:
    /// "general" for V1.5, "generalv2" otherwise.
    /// </summary>
    private string GetDomain()
    {
        return m_SparkModel == ModelType.V15 ? "general" : "generalv2";
    }

    #region Url
    /// <summary>
    /// Builds the signed endpoint URL (still with an https:// scheme).
    /// The signature is an HMAC-SHA256, keyed by the API secret, over the
    /// "host / date / request-line" string. The authorization, date and host values are
    /// passed as query parameters and percent-encoded.
    /// </summary>
    private string GetAuthUrl()
    {
        string date = DateTime.UtcNow.ToString("r");
        Uri uri = new Uri(url);

        string signatureOrigin = new StringBuilder()
            .Append("host: ").Append(uri.Host).Append("\n")
            .Append("date: ").Append(date).Append("\n")
            .Append("GET ").Append(uri.LocalPath).Append(" HTTP/1.1")
            .ToString();
        string signature = HMACsha256(m_XunfeiSettings.m_APISecret, signatureOrigin);
        string authorization = string.Format("api_key=\"{0}\", algorithm=\"{1}\", headers=\"{2}\", signature=\"{3}\"", m_XunfeiSettings.m_APIKey, "hmac-sha256", "host date request-line", signature);
        string authorizationBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(authorization));

        // Every query value is percent-encoded. For date and host this gives exactly what the
        // previous manual replacements did; for the Base64 authorization it also escapes
        // '+', '/' and '=', which must not appear raw in a query string.
        return "https://" + uri.Host + uri.LocalPath
            + "?authorization=" + Uri.EscapeDataString(authorizationBase64)
            + "&date=" + Uri.EscapeDataString(date)
            + "&host=" + Uri.EscapeDataString(uri.Host);
    }

    /// <summary>
    /// Computes HMAC-SHA256 of <paramref name="buider"/> (UTF-8) keyed with
    /// <paramref name="apiSecretIsKey"/> (UTF-8) and returns it Base64-encoded.
    /// </summary>
    public string HMACsha256(string apiSecretIsKey, string buider)
    {
        byte[] key = Encoding.UTF8.GetBytes(apiSecretIsKey);
        byte[] data = Encoding.UTF8.GetBytes(buider);
        using (System.Security.Cryptography.HMACSHA256 hmac = new System.Security.Cryptography.HMACSHA256(key))
        {
            return Convert.ToBase64String(hmac.ComputeHash(data));
        }
    }
    #endregion

    #region WebSocket connection
    /// <summary>
    /// The most recently created socket.
    /// </summary>
    private ClientWebSocket m_WebSocket;
    private CancellationToken m_CancellationToken;

    /// <summary>
    /// Connects to the service, sends <paramref name="text"/> and reads reply frames until the
    /// final one (status 2). It then appends the full reply to <c>m_DataList</c> as "assistant"
    /// and invokes <paramref name="_callback"/>. Errors are logged and the callback is not invoked.
    /// </summary>
    /// <remarks>
    /// async void because the caller (a coroutine) fires and forgets it; every exception is
    /// therefore caught here so none escapes unobserved.
    /// </remarks>
    private async void ConnectHost(string text, Action<string> _callback)
    {
        ClientWebSocket socket = null;
        try
        {
            stopwatch.Restart();
            socket = new ClientWebSocket();
            m_WebSocket = socket;
            m_CancellationToken = CancellationToken.None;

            // The signed URL is built with an http(s) scheme; the WebSocket client needs ws(s).
            string socketUrl = GetAuthUrl().Replace("http://", "ws://").Replace("https://", "wss://");
            await socket.ConnectAsync(new Uri(socketUrl), m_CancellationToken);

            byte[] requestBytes = Encoding.UTF8.GetBytes(text);
            await socket.SendAsync(new ArraySegment<byte>(requestBytes), WebSocketMessageType.Binary, true, m_CancellationToken);

            // Concatenation of the content fragments streamed back so far.
            StringBuilder reply = new StringBuilder();
            while (socket.State == WebSocketState.Open)
            {
                string json = await ReceiveMessageAsync(socket, m_CancellationToken);
                if (json == null)
                {
                    Debug.LogWarning("ChatSpark: connection closed by server before the final result");
                    break;
                }

                ResponseData response = JsonUtility.FromJson<ResponseData>(json);
                if (response.header.code != 0)
                {
                    Debug.LogError("错误码:" + response.header.code + " " + response.header.message);
                    socket.Abort();
                    break;
                }

                List<ReTextData> fragments = response.payload.choices.text;
                if (fragments == null || fragments.Count == 0)
                {
                    Debug.LogError("没有获取到回复的信息!");
                    socket.Abort();
                    break;
                }

                reply.Append(fragments[0].content);

                // status 2 marks the last fragment of the answer.
                if (response.payload.choices.status == 2)
                {
                    stopwatch.Stop();
                    Debug.Log("ChatSpark耗时" + stopwatch.Elapsed.TotalSeconds);
                    string answer = reply.ToString();
                    m_DataList.Add(new SendData("assistant", answer));
                    _callback?.Invoke(answer);
                    socket.Abort();
                    break;
                }
            }
        }
        catch (Exception ex)
        {
            Debug.LogError("报错信息: " + ex.Message);
        }
        finally
        {
            // Release the socket on every path, including success and API errors.
            if (socket != null)
            {
                socket.Dispose();
            }
        }
    }

    /// <summary>
    /// Reads one complete WebSocket message, possibly spread over several receive calls, and
    /// decodes it as UTF-8. Returns null if the peer sent a close frame instead.
    /// </summary>
    /// <remarks>
    /// The bytes are joined before decoding. Decoding each chunk separately would corrupt
    /// multi-byte characters split across chunk boundaries.
    /// </remarks>
    private static async Task<string> ReceiveMessageAsync(ClientWebSocket socket, CancellationToken token)
    {
        byte[] buffer = new byte[ReceiveBufferSize];
        using (MemoryStream message = new MemoryStream())
        {
            WebSocketReceiveResult result;
            do
            {
                result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
                if (result.MessageType == WebSocketMessageType.Close)
                {
                    return null;
                }
                // Only the first result.Count bytes of the buffer are valid.
                message.Write(buffer, 0, result.Count);
            }
            while (!result.EndOfMessage);

            return Encoding.UTF8.GetString(message.ToArray());
        }
    }
    #endregion

    #region Wire format
    // Field names below are the JSON keys produced and consumed by JsonUtility; do not rename.

    /// <summary>Request body sent to the Spark service.</summary>
    [Serializable]
    private class RequestData
    {
        public HeaderData header = new HeaderData();
        public ParameterData parameter = new ParameterData();
        public MessageData payload = new MessageData();
    }

    [Serializable]
    private class HeaderData
    {
        public string app_id = string.Empty; // required
        public string uid = "admin";         // optional user id
    }

    [Serializable]
    private class ParameterData
    {
        public ChatParameter chat = new ChatParameter();
    }

    [Serializable]
    private class ChatParameter
    {
        public string domain = "general";
        public float temperature = 0.5f;
        public int max_tokens = 1024;
    }

    [Serializable]
    private class MessageData
    {
        public TextData message = new TextData();
    }

    [Serializable]
    private class TextData
    {
        public List<PostMsgData> text = new List<PostMsgData>();
    }

    [Serializable]
    private class PostMsgData
    {
        public string role = string.Empty;
        public string content = string.Empty;
    }

    /// <summary>One response frame received from the Spark service.</summary>
    [Serializable]
    private class ResponseData
    {
        public ReHeaderData header = new ReHeaderData();
        public PayloadData payload = new PayloadData();
    }

    [Serializable]
    private class ReHeaderData
    {
        public int code;                        // error code: 0 = OK, non-zero = error
        public string message = string.Empty;   // description of whether the session succeeded
        public string sid = string.Empty;
        public int status;                      // session status: 0 = first, 1 = intermediate, 2 = last result
    }

    [Serializable]
    private class PayloadData
    {
        public ChoicesData choices = new ChoicesData();
        // "usage" is not mapped; extend here if it is needed.
    }

    [Serializable]
    private class ChoicesData
    {
        public int status;
        public int seq;
        public List<ReTextData> text = new List<ReTextData>();
    }

    [Serializable]
    private class ReTextData
    {
        public string content = string.Empty;
        public string role = string.Empty;
        public int index;
    }

    private enum ModelType
    {
        V15,
        V20
    }
    #endregion
}