330 lines
9.9 KiB
C#
330 lines
9.9 KiB
C#
|
using System;
|
|||
|
using System.Collections;
|
|||
|
using System.Collections.Generic;
|
|||
|
using System.Net.WebSockets;
|
|||
|
using System.Text;
|
|||
|
using System.Threading;
|
|||
|
using UnityEngine;
|
|||
|
|
|||
|
public class ChatSpark : LLM
{
    #region Parameters

    /// <summary>
    /// iFlytek application settings (app id, API key, API secret).
    /// If nothing is assigned in the inspector, it is looked up on this GameObject in <see cref="OnInit"/>.
    /// </summary>
    [SerializeField] private XunfeiSettings m_XunfeiSettings;

    /// <summary>
    /// Spark model version to use; selects both the endpoint URL and the request "domain".
    /// </summary>
    [Header("选择星火大模型版本")]
    [SerializeField] private ModelType m_SparkModel = ModelType.星火大模型V15;

    #endregion

    private void Awake()
    {
        OnInit();
    }

    /// <summary>
    /// Initialization: resolves the settings component and picks the endpoint URL
    /// for the selected model version.
    /// </summary>
    private void OnInit()
    {
        // Only fall back to GetComponent when nothing was assigned in the inspector, so a
        // reference to a settings component on another GameObject is not replaced by null.
        // (Unity overloads == for UnityEngine.Object, so == null is used rather than `is null`.)
        if (m_XunfeiSettings == null)
        {
            m_XunfeiSettings = GetComponent<XunfeiSettings>();
        }

        url = m_SparkModel == ModelType.星火大模型V15
            ? "https://spark-api.xf-yun.com/v1.1/chat"
            : "https://spark-api.xf-yun.com/v2.1/chat";
    }

    /// <summary>
    /// Sends a message; delegates entirely to the base <see cref="LLM"/> implementation.
    /// </summary>
    /// <param name="_msg">Message text, passed through to the base class.</param>
    /// <param name="_callback">Callback, passed through to the base class.</param>
    public override void PostMsg(string _msg, Action<string> _callback)
    {
        base.PostMsg(_msg, _callback);
    }

    /// <summary>
    /// Builds the Spark chat request from the conversation history in <c>m_DataList</c>
    /// and sends it over a WebSocket connection.
    /// </summary>
    /// <param name="_postWord">Not read here; the payload is built entirely from <c>m_DataList</c>.</param>
    /// <param name="_callback">Invoked with the complete reply once the final frame arrives.</param>
    /// <returns>Coroutine enumerator: waits one frame, then starts the (not awaited) socket exchange.</returns>
    public override IEnumerator Request(string _postWord, System.Action<string> _callback)
    {
        yield return null;

        RequestData requestData = new RequestData();
        requestData.header.app_id = m_XunfeiSettings.m_AppID;

        // "general" for V1.5, "generalv2" for V2.
        requestData.parameter.chat.domain = GetDomain();

        // Copy the whole conversation history into the request payload.
        List<PostMsgData> history = new List<PostMsgData>(m_DataList.Count);
        for (int i = 0; i < m_DataList.Count; i++)
        {
            history.Add(new PostMsgData
            {
                role = m_DataList[i].role,
                content = m_DataList[i].content
            });
        }
        requestData.payload.message.text = history;

        string json = JsonUtility.ToJson(requestData);

        // Open the WebSocket connection and stream the reply.
        ConnectHost(json, _callback);
    }

    /// <summary>
    /// Domain (model) to address: "general" targets V1.5, "generalv2" targets V2.
    /// </summary>
    /// <returns>The domain string for <see cref="m_SparkModel"/>.</returns>
    private string GetDomain()
    {
        return m_SparkModel == ModelType.星火大模型V15 ? "general" : "generalv2";
    }

    #region Authentication URL

    /// <summary>
    /// Builds the signed URL required by the Spark API: an HMAC-SHA256 signature over the
    /// host, date and request line, sent as query parameters.
    /// </summary>
    /// <returns>The endpoint URL (https scheme) with authorization, date and host query parameters.</returns>
    private string GetAuthUrl()
    {
        // RFC 1123 date, e.g. "Mon, 01 Jan 2024 00:00:00 GMT"; the same value is signed and sent.
        string date = DateTime.UtcNow.ToString("r");
        Uri uri = new Uri(url);

        // The signed text is exactly these three lines joined by '\n'.
        string signatureOrigin = "host: " + uri.Host + "\n"
            + "date: " + date + "\n"
            + "GET " + uri.LocalPath + " HTTP/1.1";

        string signature = HMACsha256(m_XunfeiSettings.m_APISecret, signatureOrigin);
        string authorization = string.Format(
            "api_key=\"{0}\", algorithm=\"{1}\", headers=\"{2}\", signature=\"{3}\"",
            m_XunfeiSettings.m_APIKey, "hmac-sha256", "host date request-line", signature);
        string authorizationBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(authorization));

        // Every query value must be percent-encoded. The Base64 authorization may contain
        // '+', '/' and '=', and an unescaped '+' is commonly decoded as a space, which breaks
        // the signature intermittently. EscapeDataString encodes the date the same way the
        // former manual replacements did (' ' -> %20, ':' -> %3A, ',' -> %2C).
        return "https://" + uri.Host + uri.LocalPath
            + "?authorization=" + Uri.EscapeDataString(authorizationBase64)
            + "&date=" + Uri.EscapeDataString(date)
            + "&host=" + Uri.EscapeDataString(uri.Host);
    }

    /// <summary>
    /// Computes HMAC-SHA256 of <paramref name="buider"/> keyed with <paramref name="apiSecretIsKey"/>
    /// (both UTF-8 encoded) and returns it Base64-encoded.
    /// </summary>
    /// <param name="apiSecretIsKey">The API secret used as the HMAC key.</param>
    /// <param name="buider">The text to sign.</param>
    /// <returns>Base64 string of the 32-byte MAC.</returns>
    public string HMACsha256(string apiSecretIsKey, string buider)
    {
        byte[] key = Encoding.UTF8.GetBytes(apiSecretIsKey);
        using (System.Security.Cryptography.HMACSHA256 hmac = new System.Security.Cryptography.HMACSHA256(key))
        {
            byte[] mac = hmac.ComputeHash(Encoding.UTF8.GetBytes(buider));
            return Convert.ToBase64String(mac);
        }
    }

    #endregion

    #region WebSocket connection

    /// <summary>
    /// Socket used by the most recent request.
    /// </summary>
    private ClientWebSocket m_WebSocket;

    /// <summary>
    /// Size of a single receive read; longer messages are assembled from several reads.
    /// </summary>
    private const int ReceiveBufferSize = 4096;

    /// <summary>
    /// Connects to the server, sends the request JSON and accumulates the streamed reply.
    /// When the final frame (choices.status == 2) arrives, it does three things:
    /// appends the answer to <c>m_DataList</c> as "assistant",
    /// invokes <paramref name="_callback"/> with it,
    /// and closes the socket.
    /// </summary>
    /// <remarks>
    /// This is async void because it is started from a coroutine and not awaited, so all exceptions
    /// are caught and logged here. On an error code, an empty reply, or an exception, the error
    /// is logged and the callback is not invoked.
    /// </remarks>
    /// <param name="text">Request JSON to send.</param>
    /// <param name="_callback">Receives the complete reply; may be null.</param>
    private async void ConnectHost(string text, Action<string> _callback)
    {
        ClientWebSocket webSocket = new ClientWebSocket();
        m_WebSocket = webSocket;
        CancellationToken token = CancellationToken.None;

        try
        {
            stopwatch.Restart();

            // The signed URL uses http(s); the socket needs the matching ws(s) scheme.
            string wsUrl = GetAuthUrl().Replace("http://", "ws://").Replace("https://", "wss://");
            await webSocket.ConnectAsync(new Uri(wsUrl), token);

            // Send the request JSON.
            await webSocket.SendAsync(new ArraySegment<byte>(Encoding.UTF8.GetBytes(text)), WebSocketMessageType.Binary, true, token);

            // Concatenation of the streamed answer fragments.
            StringBuilder answer = new StringBuilder();
            byte[] buffer = new byte[ReceiveBufferSize];
            List<byte> frame = new List<byte>();

            while (webSocket.State == WebSocketState.Open)
            {
                // Read one complete WebSocket message. Only result.Count bytes are valid in each
                // read, and bytes are collected before decoding so a multi-byte UTF-8 character
                // split across two reads is not corrupted.
                frame.Clear();
                WebSocketReceiveResult result;
                do
                {
                    result = await webSocket.ReceiveAsync(new ArraySegment<byte>(buffer), token);
                    for (int i = 0; i < result.Count; i++)
                    {
                        frame.Add(buffer[i]);
                    }
                } while (!result.EndOfMessage && result.MessageType != WebSocketMessageType.Close);

                if (result.MessageType == WebSocketMessageType.Close)
                {
                    // Server closed the connection; there is no JSON to parse.
                    break;
                }

                ResponseData response = JsonUtility.FromJson<ResponseData>(Encoding.UTF8.GetString(frame.ToArray()));

                if (response.header.code != 0)
                {
                    // Non-zero code means the server reported an error.
                    Debug.LogError("错误码:" + response.header.code + " " + response.header.message);
                    break;
                }

                if (response.payload.choices.text.Count == 0)
                {
                    // No reply data in this frame.
                    Debug.LogError("没有获取到回复的信息!");
                    break;
                }

                answer.Append(response.payload.choices.text[0].content);

                // status 2 marks the last frame of the reply.
                if (response.payload.choices.status == 2)
                {
                    stopwatch.Stop();
                    Debug.Log("ChatSpark耗时:" + stopwatch.Elapsed.TotalSeconds);

                    string reply = answer.ToString();

                    // Record the assistant turn in the conversation history.
                    m_DataList.Add(new SendData("assistant", reply));

                    _callback?.Invoke(reply);
                    break;
                }
            }
        }
        catch (Exception ex)
        {
            Debug.LogError("报错信息: " + ex.Message);
        }
        finally
        {
            // The exchange is finished (or failed): tear the socket down on every path.
            webSocket.Abort();
            webSocket.Dispose();
        }
    }

    #endregion

    #region Data definitions

    // Request payload
    [Serializable]
    private class RequestData
    {
        public HeaderData header = new HeaderData();
        public ParameterData parameter = new ParameterData();
        public MessageData payload = new MessageData();
    }

    [Serializable]
    private class HeaderData
    {
        public string app_id = string.Empty; // required
        public string uid = "admin";         // optional: user id
    }

    [Serializable]
    private class ParameterData
    {
        public ChatParameter chat = new ChatParameter();
    }

    [Serializable]
    private class ChatParameter
    {
        public string domain = "general";
        public float temperature = 0.5f;
        public int max_tokens = 1024;
    }

    [Serializable]
    private class MessageData
    {
        public TextData message = new TextData();
    }

    [Serializable]
    private class TextData
    {
        public List<PostMsgData> text = new List<PostMsgData>();
    }

    [Serializable]
    private class PostMsgData
    {
        public string role = string.Empty;
        public string content = string.Empty;
    }

    // Response payload
    [Serializable]
    private class ResponseData
    {
        public ReHeaderData header = new ReHeaderData();
        public PayloadData payload = new PayloadData();
    }

    [Serializable]
    private class ReHeaderData
    {
        public int code;                        // error code: 0 = OK, non-zero = error
        public string message = string.Empty;   // description of whether the session succeeded
        public string sid = string.Empty;
        public int status;                      // session status 0/1/2: 0 = first result, 1 = intermediate, 2 = last result
    }

    [Serializable]
    private class PayloadData
    {
        public ChoicesData choices = new ChoicesData();
        // "usage" is not used yet; extend here if needed.
    }

    [Serializable]
    private class ChoicesData
    {
        public int status;
        public int seq;
        public List<ReTextData> text = new List<ReTextData>();
    }

    [Serializable]
    private class ReTextData
    {
        public string content = string.Empty;
        public string role = string.Empty;
        public int index;
    }

    // Spark model V1.5 / V2.0
    private enum ModelType
    {
        星火大模型V15,
        星火大模型V20
    }

    #endregion
}
|