Health/Assets/_VoiceAssistant/AIChatTookit/Scripts/LLM/Baidu/ChatBaidu.cs

232 lines
6.4 KiB
C#
Raw Normal View History

2023-11-21 08:57:37 +00:00
using Newtonsoft.Json;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Networking;
public class ChatBaidu : LLM
{
    /// <summary>
    /// Sets the default endpoint (the "eb-instant" model).
    /// <see cref="Awake"/> later rebuilds it from <see cref="m_ModelType"/>.
    /// </summary>
    public ChatBaidu()
    {
        url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant";
    }

    /// <summary>
    /// Unity callback: resolves the settings component and the endpoint URL.
    /// </summary>
    void Awake()
    {
        OnInit();
    }

    /// <summary>
    /// Settings component that supplies the access token (m_Token).
    /// </summary>
    [SerializeField] private BaiduSettings m_Settings;

    /// <summary>
    /// Conversation history sent with every request (alternating user / assistant turns).
    /// </summary>
    private readonly List<message> m_History = new List<message>();

    /// <summary>
    /// Selected model; determines the endpoint path appended to the base URL.
    /// </summary>
    [Header("设置模型名称")]
    public ModelType m_ModelType = ModelType.ERNIE_Bot_turbo;

    /// <summary>
    /// Initialization: resolves <see cref="m_Settings"/> and builds <see cref="LLM.url"/>.
    /// </summary>
    private void OnInit()
    {
        // Keep a BaiduSettings assigned in the Inspector; only fall back to the one on this
        // GameObject when nothing was assigned. Unity's overloaded == handles destroyed objects.
        if (m_Settings == null)
        {
            m_Settings = GetComponent<BaiduSettings>();
        }
        GetEndPointUrl();
    }

    /// <summary>
    /// Sends a message; delegates entirely to the base implementation.
    /// </summary>
    /// <param name="_msg">Message text.</param>
    /// <param name="_callback">Invoked with the model's reply.</param>
    public override void PostMsg(string _msg, Action<string> _callback)
    {
        base.PostMsg(_msg, _callback);
    }

    /// <summary>
    /// Coroutine that posts the conversation (history plus <paramref name="_postWord"/>) to the
    /// Baidu chat endpoint.
    /// On success the reply is appended to the history and <see cref="m_DataList"/>, and
    /// passed to <paramref name="_callback"/>.
    /// On failure the error is logged, the callback is not invoked, and the user turn added
    /// by this call is removed again.
    /// </summary>
    /// <param name="_postWord">User message to send.</param>
    /// <param name="_callback">Invoked with the reply text on success.</param>
    /// <returns>Coroutine enumerator.</returns>
    public override IEnumerator Request(string _postWord, System.Action<string> _callback)
    {
        if (m_Settings == null)
        {
            Debug.LogError("ChatBaidu: BaiduSettings is missing; request skipped.");
            yield break;
        }

        stopwatch.Restart();
        string _postUrl = url + "?access_token=" + m_Settings.m_Token;

        // Keep a reference so exactly this turn can be rolled back if the request fails.
        message _userTurn = new message("user", _postWord);
        m_History.Add(_userTurn);

        RequestData _postData = new RequestData
        {
            messages = m_History
        };

        using (UnityWebRequest request = new UnityWebRequest(_postUrl, "POST"))
        {
            // Serialized before the yield, so later history changes cannot affect this request body.
            string _jsonData = JsonUtility.ToJson(_postData);
            byte[] data = System.Text.Encoding.UTF8.GetBytes(_jsonData);
            request.uploadHandler = new UploadHandlerRaw(data);
            request.downloadHandler = new DownloadHandlerBuffer();
            request.SetRequestHeader("Content-Type", "application/json");
            yield return request.SendWebRequest();

            string _reply;
            if (TryReadReply(request, out _reply))
            {
                // Conversation history.
                m_History.Add(new message("assistant", _reply));
                // Record list.
                m_DataList.Add(new SendData("assistant", _reply));
                _callback?.Invoke(_reply);
            }
            else
            {
                // Drop the unanswered user turn so the next request does not send
                // two consecutive "user" messages.
                m_History.Remove(_userTurn);
            }
        }

        stopwatch.Stop();
        Debug.Log("chat百度-耗时:" + stopwatch.Elapsed.TotalSeconds);
    }

    /// <summary>
    /// Extracts the reply text from a completed request.
    /// </summary>
    /// <param name="_request">The finished web request.</param>
    /// <param name="_reply">The model's reply when this returns true; otherwise null.</param>
    /// <returns>
    /// False (after logging) on a non-200 status, an unparsable body, or an API-level error
    /// code; otherwise true.
    /// </returns>
    private bool TryReadReply(UnityWebRequest _request, out string _reply)
    {
        _reply = null;

        if (_request.responseCode != 200)
        {
            Debug.LogError("ChatBaidu: HTTP " + _request.responseCode + " " + _request.error);
            return false;
        }

        string _body = _request.downloadHandler.text;
        ResponseData response;
        try
        {
            response = JsonConvert.DeserializeObject<ResponseData>(_body);
        }
        catch (JsonException e)
        {
            Debug.LogError("ChatBaidu: could not parse response: " + e.Message + "\n" + _body);
            return false;
        }

        if (response == null)
        {
            Debug.LogError("ChatBaidu: empty response body.");
            return false;
        }

        // The API can report failures in the body (error_code / error_msg) while the HTTP status is 200.
        if (response.error_code != 0)
        {
            Debug.LogError("ChatBaidu: API error " + response.error_code + ": " + response.error_msg);
            return false;
        }

        _reply = response.result;
        return true;
    }

    /// <summary>
    /// Rebuilds <see cref="LLM.url"/> from the selected <see cref="m_ModelType"/>.
    /// </summary>
    private void GetEndPointUrl()
    {
        url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + CheckModelType(m_ModelType);
    }

    /// <summary>
    /// Maps a model type to its endpoint path segment.
    /// </summary>
    /// <param name="_type">Model type.</param>
    /// <returns>The path segment, or an empty string for an unmapped value.</returns>
    private string CheckModelType(ModelType _type)
    {
        switch (_type)
        {
            case ModelType.ERNIE_Bot: return "completions";
            case ModelType.ERNIE_Bot_turbo: return "eb-instant";
            case ModelType.BLOOMZ_7B: return "bloomz_7b1";
            case ModelType.Qianfan_BLOOMZ_7B_compressed: return "qianfan_bloomz_7b_compressed";
            case ModelType.ChatGLM2_6B_32K: return "chatglm2_6b_32k";
            case ModelType.Llama_2_7B_Chat: return "llama_2_7b";
            case ModelType.Llama_2_13B_Chat: return "llama_2_13b";
            case ModelType.Llama_2_70B_Chat: return "llama_2_70b";
            case ModelType.Qianfan_Chinese_Llama_2_7B: return "qianfan_chinese_llama_2_7b";
            case ModelType.AquilaChat_7B: return "aquilachat_7b";
            default: return "";
        }
    }

    #region Wire data models
    // Request body.
    [Serializable]
    private class RequestData
    {
        public List<message> messages = new List<message>(); // conversation messages to send
        public bool stream = false;                           // whether to use streaming output
        public string user_id = string.Empty;                 // end-user identifier
    }

    [Serializable]
    private class message
    {
        public string role = string.Empty;    // speaker role ("user" / "assistant")
        public string content = string.Empty; // message text
        public message() { }
        public message(string _role, string _content)
        {
            role = _role;
            content = _content;
        }
    }

    // Response body.
    [Serializable]
    private class ResponseData
    {
        public string id = string.Empty;     // id of this conversation round
        public int created;                  // creation timestamp
        public int sentence_id;              // index of the current sentence; only returned in streaming mode
        public bool is_end;                  // whether this is the last sentence; only returned in streaming mode
        public bool is_truncated;            // whether the generated result was truncated
        public string result = string.Empty; // generated reply text
        public bool need_clear_history;      // whether the user input was flagged as unsafe and history should be cleared
        public int ban_round;                // when need_clear_history is true: the round containing sensitive content (-1 = current question)
        public Usage usage = new Usage();    // token statistics (tokens ≈ Chinese characters + words * 1.3)
        public int error_code;               // non-zero when the API reports an error
        public string error_msg = string.Empty; // error description accompanying error_code
    }

    [Serializable]
    private class Usage
    {
        public int prompt_tokens;     // tokens in the prompt
        public int completion_tokens; // tokens in the reply
        public int total_tokens;      // total tokens
    }
    #endregion

    /// <summary>
    /// Supported model names.
    /// </summary>
    public enum ModelType
    {
        ERNIE_Bot,
        ERNIE_Bot_turbo,
        BLOOMZ_7B,
        Qianfan_BLOOMZ_7B_compressed,
        ChatGLM2_6B_32K,
        Llama_2_7B_Chat,
        Llama_2_13B_Chat,
        Llama_2_70B_Chat,
        Qianfan_Chinese_Llama_2_7B,
        AquilaChat_7B,
    }
}